Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic OpenAI API provider #112

Merged
merged 4 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import (
_ "github.com/symflower/eval-dev-quality/language/java" // Register language.
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
"github.com/symflower/eval-dev-quality/provider"
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/ollama" // Register provider.
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
_ "github.com/symflower/eval-dev-quality/provider/openrouter" // Register provider.
_ "github.com/symflower/eval-dev-quality/provider/symflower" // Register provider.
"github.com/symflower/eval-dev-quality/tools"
Expand All @@ -40,11 +42,12 @@ type Evaluate struct {

// Languages determines which language should be used for the evaluation, or empty if all languages should be used.
Languages []string `long:"language" description:"Evaluate with this language. By default all languages are used."`

// Models determines which models should be used for the evaluation, or empty if all models should be used.
Models []string `long:"model" description:"Evaluate with this model. By default all models are used."`
// ProviderTokens holds all API tokens for the providers.
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token,...')." env:"PROVIDER_TOKEN"`
ProviderTokens map[string]string `long:"tokens" description:"API tokens for model providers (of the form '$provider:$token'). When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_TOKEN" env-delim:","`
// ProviderUrls holds all custom inference endpoint urls for the providers.
ProviderUrls map[string]string `long:"urls" description:"Custom OpenAI API compatible inference endpoints (of the form '$provider:$url,...'). Use '$provider=custom-$name' to manually register a custom OpenAI API endpoint provider. Note that the models of a custom OpenAI API endpoint provider must be declared explicitly using the '--model' option. When using the environment variable, separate multiple definitions with ','." env:"PROVIDER_URL" env-delim:","`
// QueryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
QueryAttempts uint `long:"attempts" description:"Number of query attempts to perform when a model request errors in the process of solving a task." default:"3"`

Expand Down Expand Up @@ -112,6 +115,36 @@ func (command *Evaluate) Execute(args []string) (err error) {
}
}

// Register custom OpenAI API providers and models.
{
customProviders := map[string]*openaiapi.Provider{}
for providerID, providerURL := range command.ProviderUrls {
if !strings.HasPrefix(providerID, "custom-") {
continue
}

p := openaiapi.NewProvider(providerID, providerURL)
provider.Register(p)
customProviders[providerID] = p
}
for _, model := range command.Models {
if !strings.HasPrefix(model, "custom-") {
continue
}

providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
if !ok {
log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
}
modelProvider, ok := customProviders[providerID]
if !ok {
log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
}
}

// Gather languages.
languagesSelected := map[string]language.Language{}
{
Expand Down
58 changes: 54 additions & 4 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"fmt"
"net/url"
"os"
"path/filepath"
"regexp"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/symflower/eval-dev-quality/evaluate/metrics"
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
"github.com/symflower/eval-dev-quality/log"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
"github.com/symflower/eval-dev-quality/tools"
)

Expand Down Expand Up @@ -501,12 +503,12 @@ func TestEvaluateExecute(t *testing.T) {
shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
require.NoError(t, err)

require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, "qwen:0.5b"))
require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel))
},

Arguments: []string{
"--language", "golang",
"--model", "ollama/qwen:0.5b",
"--model", "ollama/" + providertesting.OllamaTestModel,
"--repository", filepath.Join("golang", "plain"),
},

Expand All @@ -515,15 +517,63 @@ func TestEvaluateExecute(t *testing.T) {
"evaluation.csv": nil,
"evaluation.log": func(t *testing.T, filePath, data string) {
// Since the model is non-deterministic, we can only assert that the model did at least not error.
assert.Contains(t, data, `Evaluation score for "ollama/qwen:0.5b"`)
assert.Contains(t, data, fmt.Sprintf(`Evaluation score for "ollama/%s"`, providertesting.OllamaTestModel))
assert.Contains(t, data, "response-no-error=1")
assert.Contains(t, data, "preloading model")
assert.Contains(t, data, "unloading model")
},
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"ollama_qwen:0.5b/golang/golang/plain.log": nil,
"ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
})
t.Run("OpenAI API", func(t *testing.T) {
if !osutil.IsLinux() {
t.Skipf("Installation of Ollama is not supported on this OS")
}

{
var shutdown func() (err error)
defer func() {
if shutdown != nil {
require.NoError(t, shutdown())
}
}()
ollamaOpenAIAPIUrl, err := url.JoinPath(tools.OllamaURL, "v1")
require.NoError(t, err)
validate(t, &testCase{
Name: "Ollama",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
var err error
shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
require.NoError(t, err)

require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel))
},

Arguments: []string{
"--language", "golang",
"--urls", "custom-ollama:" + ollamaOpenAIAPIUrl,
bauersimon marked this conversation as resolved.
Show resolved Hide resolved
"--model", "custom-ollama/" + providertesting.OllamaTestModel,
"--repository", filepath.Join("golang", "plain"),
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
"categories.svg": nil,
"evaluation.csv": nil,
"evaluation.log": func(t *testing.T, filePath, data string) {
// Since the model is non-deterministic, we can only assert that the model did at least not error.
assert.Contains(t, data, fmt.Sprintf(`Evaluation score for "custom-ollama/%s"`, providertesting.OllamaTestModel))
assert.Contains(t, data, "response-no-error=1")
},
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
Expand Down
22 changes: 2 additions & 20 deletions provider/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ollama

import (
"context"
"fmt"
"net/url"
"strings"

Expand All @@ -13,6 +12,7 @@ import (
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
"github.com/symflower/eval-dev-quality/provider"
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
"github.com/symflower/eval-dev-quality/tools"
)

Expand Down Expand Up @@ -85,25 +85,7 @@ func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText
client := p.client()
modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)

apiResponse, err := client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: modelIdentifier,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: promptText,
},
},
},
)
if err != nil {
return "", pkgerrors.WithStack(err)
} else if len(apiResponse.Choices) == 0 {
return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse))
}

return apiResponse.Choices[0].Message.Content, nil
return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText)
}

// client returns a new client with the current configuration.
Expand Down
5 changes: 3 additions & 2 deletions provider/ollama/ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/zimmski/osutil"

"github.com/symflower/eval-dev-quality/log"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
"github.com/symflower/eval-dev-quality/tools"
)

Expand Down Expand Up @@ -143,11 +144,11 @@ func TestProviderModels(t *testing.T) {
Name: "Local Model",

LocalModels: []string{
"qwen:0.5b",
providertesting.OllamaTestModel,
},

ExpectedModels: []string{
"ollama/qwen:0.5b",
"ollama/" + providertesting.OllamaTestModel,
},
})
}
76 changes: 76 additions & 0 deletions provider/openai-api/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package openaiapi

import (
"context"
"strings"

"github.com/sashabaranov/go-openai"

"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
)

// Provider holds a generic "OpenAI API" provider.
type Provider struct {
	// baseURL holds the base URL of the OpenAI API compatible inference endpoint.
	baseURL string
	// token holds an optional API access token used to authenticate against the endpoint.
	token string
	// id holds the unique ID of this provider.
	id string
	// models holds the models that were manually registered for this provider.
	models []model.Model
}

// NewProvider returns a generic "OpenAI API" provider for the given OpenAI API compatible inference endpoint.
// The returned provider has no token set and no models registered.
func NewProvider(id string, baseURL string) (provider *Provider) {
	p := &Provider{
		id:      id,
		baseURL: baseURL,
	}

	return p
}

// Ensure at compile time that "Provider" implements the "provider.Provider" interface.
var _ provider.Provider = (*Provider)(nil)

// Available checks if the provider is ready to be used.
// This might include checking for an installation or making sure an API access token is valid.
// The logger is currently unused because no check is performed.
func (p *Provider) Available(logger *log.Logger) (err error) {
	return nil // We cannot know if a custom provider requires an API.
}

// ID returns the unique ID of this provider.
func (p *Provider) ID() (id string) {
	return p.id
}

// Models returns which models are available to be queried via this provider.
// Only models that were manually registered through "AddModel" are returned; the endpoint itself is not queried.
func (p *Provider) Models() (models []model.Model, err error) {
	return p.models, nil
}

// AddModel manually adds a model to the provider.
// NOTE(review): appends without synchronization — presumably all registration happens before concurrent use; confirm against callers.
func (p *Provider) AddModel(m model.Model) {
	p.models = append(p.models, m)
}

// Ensure at compile time that "Provider" implements the "provider.InjectToken" interface.
var _ provider.InjectToken = (*Provider)(nil)

// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API.
// The token is passed to every client created by "client".
func (p *Provider) SetToken(token string) {
	p.token = token
}

// Ensure at compile time that "Provider" implements the "provider.Query" interface.
var _ provider.Query = (*Provider)(nil)

// Query queries the provider with the given model name.
func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) {
	client := p.client()
	// Strip the "$provider/" prefix so only the bare model name is sent to the endpoint.
	modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)

	return QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText)
}

// client returns a new client with the current configuration.
// The client targets the provider's custom base URL instead of the official OpenAI endpoint.
func (p *Provider) client() (client *openai.Client) {
	config := openai.DefaultConfig(p.token)
	config.BaseURL = p.baseURL

	return openai.NewClientWithConfig(config)
}
32 changes: 32 additions & 0 deletions provider/openai-api/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package openaiapi

import (
"context"
"fmt"

pkgerrors "github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
)

// QueryOpenAIModel queries an OpenAI API model.
func QueryOpenAIAPIModel(ctx context.Context, client *openai.Client, modelIdentifier string, promptText string) (response string, err error) {
Munsio marked this conversation as resolved.
Show resolved Hide resolved
apiResponse, err := client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: modelIdentifier,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: promptText,
},
},
},
)
if err != nil {
return "", pkgerrors.WithStack(err)
} else if len(apiResponse.Choices) == 0 {
return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse))
}

return apiResponse.Choices[0].Message.Content, nil
}
22 changes: 2 additions & 20 deletions provider/openrouter/openrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openrouter
import (
"context"
"errors"
"fmt"
"strings"

pkgerrors "github.com/pkg/errors"
Expand All @@ -13,6 +12,7 @@ import (
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm"
"github.com/symflower/eval-dev-quality/provider"
openaiapi "github.com/symflower/eval-dev-quality/provider/openai-api"
)

// Provider holds an "openrouter.ai" provider using its public REST API.
Expand Down Expand Up @@ -79,25 +79,7 @@ func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText
client := p.client()
modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator)

apiResponse, err := client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: modelIdentifier,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: promptText,
},
},
},
)
if err != nil {
return "", pkgerrors.WithStack(err)
} else if len(apiResponse.Choices) == 0 {
return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse))
}

return apiResponse.Choices[0].Message.Content, nil
return openaiapi.QueryOpenAIAPIModel(ctx, client, modelIdentifier, promptText)
}

// client returns a new client with the current configuration.
Expand Down
3 changes: 3 additions & 0 deletions provider/testing/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ func NewMockProviderNamedWithModels(t *testing.T, id string, models []model.Mode

return m
}

// OllamaTestModel holds the smallest Ollama model that we use for testing.
// Using one shared constant keeps the model identifier consistent across all test files.
const OllamaTestModel = "qwen:0.5b"
Loading
Loading