Skip to content

Commit

Permalink
Generic OpenAI API provider that can be configured via command line
Browse files Browse the repository at this point in the history
Part of #111
  • Loading branch information
bauersimon committed May 15, 2024
1 parent 0d5b237 commit ab3da39
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 1 deletion.
36 changes: 35 additions & 1 deletion cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
_ "github.com/symflower/eval-dev-quality/language/java" // Register language.
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
"github.com/symflower/eval-dev-quality/provider"
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
_ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider.
"github.com/symflower/eval-dev-quality/tools"
Expand Down Expand Up @@ -51,6 +53,8 @@ type Evaluate struct {

// ProviderTokens holds all API tokens for the providers.
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
// ProviderUrls holds all custom inference endpoint urls for the providers.
ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","`

// logger holds the logger of the command.
logger *log.Logger
Expand Down Expand Up @@ -99,6 +103,36 @@ func (command *Evaluate) Execute(args []string) (err error) {
}
}

// Register custom OpenAI API providers and models.
{
customProviders := map[string]*openaiapi.Provider{}
for providerID, providerURL := range command.ProviderUrls {
if !strings.HasPrefix(providerID, "custom-") {
continue
}

p := openaiapi.NewProvider(providerID, providerURL)
provider.Register(p)
customProviders[providerID] = p
}
for _, model := range command.Models {
if !strings.HasPrefix(model, "custom-") {
continue
}

providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
if !ok {
log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
}
modelProvider, ok := customProviders[providerID]
if !ok {
log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
}
}

// Gather languages.
languagesSelected := map[string]language.Language{}
{
Expand Down
49 changes: 49 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -541,6 +542,54 @@ func TestEvaluateExecute(t *testing.T) {
})
}
})
t.Run("OpenAI API", func(t *testing.T) {
if !osutil.IsLinux() {
t.Skipf("Installation of Ollama is not supported on this OS")
}

{
var shutdown func() (err error)
defer func() {
if shutdown != nil {
require.NoError(t, shutdown())
}
}()
ollamaOpenAIAPIUrl, err := url.JoinPath(tools.OllamaURL, "v1")
require.NoError(t, err)
validate(t, &testCase{
Name: "Ollama",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
var err error
shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
require.NoError(t, err)

require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b"))
},

Arguments: []string{
"--language", "golang",
"--urls", "custom-ollama:" + ollamaOpenAIAPIUrl,
"--model", "custom-ollama/qwen:0.5b",
"--repository", filepath.Join("golang", "plain"),
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
"categories.svg": nil,
"evaluation.csv": nil,
"evaluation.log": func(t *testing.T, filePath, data string) {
// Since the model is non-deterministic, we can only assert that the model did at least not error.
assert.Contains(t, data, `Evaluation score for "custom-ollama/qwen:0.5b"`)
assert.Contains(t, data, "response-no-error=1")
},
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"custom-ollama_qwen:0.5b/golang/golang/plain.log": nil,
},
})
}
})
})

// This case checks a beautiful bug where the Markdown export crashed when the current working directory contained a README.md file. While this is not the case during the tests (as the current work directory is the directory of this file), it certainly caused problems when our binary was executed from the repository root (which of course contained a README.md). Therefore, we sadly have to modify the current work directory right within the tests of this case to reproduce the problem and fix it forever.
Expand Down
96 changes: 96 additions & 0 deletions provider/openai-api/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package openaiapi

import (
"context"
"fmt"
"strings"

pkgerrors "github.com/pkg/errors"
"github.com/sashabaranov/go-openai"

"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
)

// Provider holds a generic "OpenAI API" provider.
type Provider struct {
	// baseURL holds the base URL of the OpenAI-compatible inference endpoint.
	baseURL string
	// token holds an optional API token used to authenticate requests.
	token string
	// id holds the unique ID of this provider.
	id string
	// models holds the models that were manually registered via "AddModel".
	models []model.Model
}

// NewProvider returns a generic "OpenAI API" provider that serves the given
// endpoint URL under the given provider ID.
func NewProvider(id string, baseURL string) (provider *Provider) {
	p := &Provider{}
	p.id = id
	p.baseURL = baseURL

	return p
}

var _ provider.Provider = (*Provider)(nil)

// Available checks if the provider is ready to be used.
// This might include checking for an installation or making sure an API access token is valid.
// A custom OpenAI API endpoint is always reported as available because there is no generic way to probe it.
func (p *Provider) Available(logger *log.Logger) (err error) {
	return nil // We cannot know if a custom provider requires an API.
}

// ID returns the unique ID of this provider.
func (p *Provider) ID() (id string) {
	id = p.id

	return id
}

// Models returns which models are available to be queried via this provider.
// Only models that were registered explicitly via "AddModel" are returned.
func (p *Provider) Models() (models []model.Model, err error) {
	models = p.models

	return models, nil
}

// AddModel manually adds a model to the provider.
func (p *Provider) AddModel(m model.Model) {
	registered := p.models
	registered = append(registered, m)
	p.models = registered
}

var _ provider.InjectToken = (*Provider)(nil)

// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API.
// The token is forwarded to every client created by "client".
func (p *Provider) SetToken(token string) {
	p.token = token
}

var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
// The provider's ID prefix is stripped from the model identifier before it is sent to the endpoint.
func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) {
	modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)

	request := openai.ChatCompletionRequest{
		Model: modelIdentifier,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: promptText,
			},
		},
	}
	apiResponse, err := p.client().CreateChatCompletion(ctx, request)
	if err != nil {
		return "", pkgerrors.WithStack(err)
	}
	if len(apiResponse.Choices) == 0 {
		return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse))
	}

	return apiResponse.Choices[0].Message.Content, nil
}

// client returns a new client with the current configuration.
// The client targets the provider's base URL and authenticates with the provider's token.
func (p *Provider) client() (client *openai.Client) {
	configuration := openai.DefaultConfig(p.token)
	configuration.BaseURL = p.baseURL
	client = openai.NewClientWithConfig(configuration)

	return client
}

0 comments on commit ab3da39

Please sign in to comment.