-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generic OpenAI API provider that can be configured via command line
Part of #111
- Loading branch information
1 parent
0d5b237
commit ab3da39
Showing
3 changed files
with
180 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package openaiapi | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"strings" | ||
|
||
pkgerrors "github.com/pkg/errors" | ||
"github.com/sashabaranov/go-openai" | ||
|
||
"github.com/symflower/eval-dev-quality/log" | ||
"github.com/symflower/eval-dev-quality/model" | ||
"github.com/symflower/eval-dev-quality/provider" | ||
) | ||
|
||
// Provider holds a generic "OpenAI API" provider. | ||
type Provider struct { | ||
baseURL string | ||
token string | ||
id string | ||
models []model.Model | ||
} | ||
|
||
// NewProvider returns a generic "OpenAI API" provider. | ||
func NewProvider(id string, baseURL string) (provider *Provider) { | ||
return &Provider{ | ||
baseURL: baseURL, | ||
id: id, | ||
} | ||
} | ||
|
||
var _ provider.Provider = (*Provider)(nil) | ||
|
||
// Available checks if the provider is ready to be used. | ||
// This might include checking for an installation or making sure an API access token is valid. | ||
func (p *Provider) Available(logger *log.Logger) (err error) { | ||
return nil // We cannot know if a custom provider requires an API. | ||
} | ||
|
||
// ID returns the unique ID of this provider. | ||
func (p *Provider) ID() (id string) { | ||
return p.id | ||
} | ||
|
||
// Models returns which models are available to be queried via this provider. | ||
func (p *Provider) Models() (models []model.Model, err error) { | ||
return p.models, nil | ||
} | ||
|
||
// AddModel manually adds a model to the provider. | ||
func (p *Provider) AddModel(m model.Model) { | ||
p.models = append(p.models, m) | ||
} | ||
|
||
var _ provider.InjectToken = (*Provider)(nil) | ||
|
||
// SetToken sets a potential token to be used in case the provider needs to authenticate a remote API. | ||
func (p *Provider) SetToken(token string) { | ||
p.token = token | ||
} | ||
|
||
var _ provider.Query = (*Provider)(nil) | ||
|
||
// Query queries the provider with the given model name. | ||
func (p *Provider) Query(ctx context.Context, modelIdentifier string, promptText string) (response string, err error) { | ||
client := p.client() | ||
modelIdentifier = strings.TrimPrefix(modelIdentifier, p.ID()+provider.ProviderModelSeparator) | ||
|
||
apiResponse, err := client.CreateChatCompletion( | ||
ctx, | ||
openai.ChatCompletionRequest{ | ||
Model: modelIdentifier, | ||
Messages: []openai.ChatCompletionMessage{ | ||
{ | ||
Role: openai.ChatMessageRoleUser, | ||
Content: promptText, | ||
}, | ||
}, | ||
}, | ||
) | ||
if err != nil { | ||
return "", pkgerrors.WithStack(err) | ||
} else if len(apiResponse.Choices) == 0 { | ||
return "", pkgerrors.WithStack(fmt.Errorf("empty LLM %q response: %+v", modelIdentifier, apiResponse)) | ||
} | ||
|
||
return apiResponse.Choices[0].Message.Content, nil | ||
} | ||
|
||
// client returns a new client with the current configuration. | ||
func (p *Provider) client() (client *openai.Client) { | ||
config := openai.DefaultConfig(p.token) | ||
config.BaseURL = p.baseURL | ||
|
||
return openai.NewClientWithConfig(config) | ||
} |