Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for baseURL in generative-openai azure config #4124

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 9 additions & 5 deletions modules/generative-openai/clients/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,15 @@ import (

var compile, _ = regexp.Compile(`{([\w\s]*?)}`)

func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
func buildUrlFn(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
if resourceName != "" && deploymentID != "" {
host := "https://" + resourceName + ".openai.azure.com"
host := baseURL
if host == "" || host == "https://api.openai.com" {
// Fall back to old assumption
host = "https://" + resourceName + ".openai.azure.com"
}
path := "openai/deployments/" + deploymentID + "/chat/completions"
queryParam := "api-version=2023-03-15-preview"
queryParam := fmt.Sprintf("api-version=%s", apiVersion)
return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil
}
path := "/v1/chat/completions"
Expand All @@ -53,7 +57,7 @@ type openai struct {
openAIApiKey string
openAIOrganization string
azureApiKey string
buildUrl func(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error)
buildUrl func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error)
httpClient *http.Client
logger logrus.FieldLogger
}
Expand Down Expand Up @@ -167,7 +171,7 @@ func (v *openai) buildOpenAIUrl(ctx context.Context, settings config.ClassSettin
if headerBaseURL := v.getValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" {
baseURL = headerBaseURL
}
return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL)
return v.buildUrl(settings.IsLegacy(), settings.ResourceName(), settings.DeploymentID(), baseURL, settings.ApiVersion())
}

func (v *openai) generateInput(prompt string, settings config.ClassSettings) (generateInput, error) {
Expand Down
31 changes: 20 additions & 11 deletions modules/generative-openai/clients/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func nullLogger() logrus.FieldLogger {
return l
}

func fakeBuildUrl(serverURL string, isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
endpoint, err := buildUrlFn(isLegacy, resourceName, deploymentID, baseURL)
func fakeBuildUrl(serverURL string, isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
endpoint, err := buildUrlFn(isLegacy, resourceName, deploymentID, baseURL, apiVersion)
if err != nil {
return "", err
}
Expand All @@ -46,22 +46,26 @@ func fakeBuildUrl(serverURL string, isLegacy bool, resourceName, deploymentID, b

func TestBuildUrlFn(t *testing.T) {
t.Run("buildUrlFn returns default OpenAI Client", func(t *testing.T) {
url, err := buildUrlFn(false, "", "", config.DefaultOpenAIBaseURL)
url, err := buildUrlFn(false, "", "", config.DefaultOpenAIBaseURL, config.DefaultApiVersion)
assert.Nil(t, err)
assert.Equal(t, "https://api.openai.com/v1/chat/completions", url)
})
t.Run("buildUrlFn returns Azure Client", func(t *testing.T) {
url, err := buildUrlFn(false, "resourceID", "deploymentID", "")
url, err := buildUrlFn(false, "resourceID", "deploymentID", "", config.DefaultApiVersion)
assert.Nil(t, err)
assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/chat/completions?api-version=2023-03-15-preview", url)
assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url)
})

t.Run("buildUrlFn loads from environment variable", func(t *testing.T) {
url, err := buildUrlFn(false, "", "", "https://foobar.some.proxy")
url, err := buildUrlFn(false, "", "", "https://foobar.some.proxy", config.DefaultApiVersion)
assert.Nil(t, err)
assert.Equal(t, "https://foobar.some.proxy/v1/chat/completions", url)
os.Unsetenv("OPENAI_BASE_URL")
})
t.Run("buildUrlFn returns Azure Client with custom baseURL", func(t *testing.T) {
url, err := buildUrlFn(false, "resourceID", "deploymentID", "customBaseURL", config.DefaultApiVersion)
assert.Nil(t, err)
assert.Equal(t, "customBaseURL/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url)
})
}

func TestGetAnswer(t *testing.T) {
Expand All @@ -83,8 +87,8 @@ func TestGetAnswer(t *testing.T) {
defer server.Close()

c := New("openAIApiKey", "", "", 0, nullLogger())
c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL)
c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion)
}

expected := generativemodels.GenerateResponse{
Expand All @@ -109,8 +113,8 @@ func TestGetAnswer(t *testing.T) {
defer server.Close()

c := New("openAIApiKey", "", "", 0, nullLogger())
c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL string) (string, error) {
return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL)
c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion)
}

_, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
Expand Down Expand Up @@ -239,6 +243,7 @@ type fakeClassSettings struct {
deploymentID string
isAzure bool
baseURL string
apiVersion string
}

func (s *fakeClassSettings) IsLegacy() bool {
Expand Down Expand Up @@ -292,3 +297,7 @@ func (s *fakeClassSettings) Validate(class *models.Class) error {
func (s *fakeClassSettings) BaseURL() string {
return s.baseURL
}

func (s *fakeClassSettings) ApiVersion() string {
return s.apiVersion
}
27 changes: 27 additions & 0 deletions modules/generative-openai/config/class_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
presencePenaltyProperty = "presencePenalty"
topPProperty = "topP"
baseURLProperty = "baseURL"
apiVersionProperty = "apiVersion"
)

var availableOpenAILegacyModels = []string{
Expand All @@ -52,6 +53,7 @@ var (
DefaultOpenAIPresencePenalty = 0.0
DefaultOpenAITopP = 1.0
DefaultOpenAIBaseURL = "https://api.openai.com"
DefaultApiVersion = "2023-05-15"
)

// todo Need to parse the tokenLimits in a smarter way, as the prompt defines the max length
Expand All @@ -66,6 +68,17 @@ var defaultMaxTokens = map[string]float64{
"gpt-4-1106-preview": 128000,
}

var availableApiVersions = []string{
"2022-12-01",
"2023-03-15-preview",
"2023-05-15",
"2023-06-01-preview",
"2023-07-01-preview",
"2023-08-01-preview",
"2023-09-01-preview",
"2023-12-01-preview",
}

type ClassSettings interface {
IsLegacy() bool
Model() string
Expand All @@ -80,6 +93,7 @@ type ClassSettings interface {
GetMaxTokensForModel(model string) float64
Validate(class *models.Class) error
BaseURL() string
ApiVersion() string
}

type classSettings struct {
Expand Down Expand Up @@ -126,6 +140,11 @@ func (ic *classSettings) Validate(class *models.Class) error {
return errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5")
}

apiVersion := ic.ApiVersion()
if !ic.validateApiVersion(apiVersion) {
return errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: %v", availableApiVersions)
}

err := ic.validateAzureConfig(ic.ResourceName(), ic.DeploymentID())
if err != nil {
return err
Expand Down Expand Up @@ -192,6 +211,10 @@ func (ic *classSettings) validateModel(model string) bool {
return contains(availableOpenAIModels, model) || contains(availableOpenAILegacyModels, model)
}

func (ic *classSettings) validateApiVersion(apiVersion string) bool {
return contains(availableApiVersions, apiVersion)
}

func (ic *classSettings) IsLegacy() bool {
return contains(availableOpenAILegacyModels, ic.Model())
}
Expand All @@ -208,6 +231,10 @@ func (ic *classSettings) BaseURL() string {
return *ic.getStringProperty(baseURLProperty, DefaultOpenAIBaseURL)
}

func (ic *classSettings) ApiVersion() string {
return *ic.getStringProperty(apiVersionProperty, DefaultApiVersion)
}

func (ic *classSettings) Temperature() float64 {
return *ic.getFloatProperty(temperatureProperty, &DefaultOpenAITemperature)
}
Expand Down
46 changes: 46 additions & 0 deletions modules/generative-openai/config/class_settings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func Test_classSettings_Validate(t *testing.T) {
wantIsAzure bool
wantErr error
wantBaseURL string
wantApiVersion string
}{
{
name: "Happy flow",
Expand All @@ -48,6 +49,7 @@ func Test_classSettings_Validate(t *testing.T) {
wantPresencePenalty: 0.0,
wantErr: nil,
wantBaseURL: "https://api.openai.com",
wantApiVersion: "2023-05-15",
},
{
name: "Everything non default configured",
Expand All @@ -69,6 +71,7 @@ func Test_classSettings_Validate(t *testing.T) {
wantPresencePenalty: 0.9,
wantErr: nil,
wantBaseURL: "https://api.openai.com",
wantApiVersion: "2023-05-15",
},
{
name: "OpenAI Proxy",
Expand All @@ -84,6 +87,7 @@ func Test_classSettings_Validate(t *testing.T) {
},
},
wantBaseURL: "https://proxy.weaviate.dev/",
wantApiVersion: "2023-05-15",
wantModel: "gpt-3.5-turbo",
wantMaxTokens: 4097,
wantTemperature: 0.5,
Expand Down Expand Up @@ -112,6 +116,7 @@ func Test_classSettings_Validate(t *testing.T) {
wantPresencePenalty: 0.9,
wantErr: nil,
wantBaseURL: "https://api.openai.com",
wantApiVersion: "2023-05-15",
},
{
name: "Azure OpenAI config",
Expand All @@ -137,6 +142,34 @@ func Test_classSettings_Validate(t *testing.T) {
wantPresencePenalty: 0.9,
wantErr: nil,
wantBaseURL: "https://api.openai.com",
wantApiVersion: "2023-05-15",
},
{
name: "Azure OpenAI config with baseURL",
cfg: fakeClassConfig{
classConfig: map[string]interface{}{
"baseURL": "some-base-url",
"resourceName": "weaviate",
"deploymentId": "gpt-3.5-turbo",
"maxTokens": 4097,
"temperature": 0.5,
"topP": 3,
"frequencyPenalty": 0.1,
"presencePenalty": 0.9,
},
},
wantResourceName: "weaviate",
wantDeploymentID: "gpt-3.5-turbo",
wantIsAzure: true,
wantModel: "gpt-3.5-turbo",
wantMaxTokens: 4097,
wantTemperature: 0.5,
wantTopP: 3,
wantFrequencyPenalty: 0.1,
wantPresencePenalty: 0.9,
wantErr: nil,
wantBaseURL: "some-base-url",
wantApiVersion: "2023-05-15",
},
{
name: "With gpt-3.5-turbo-16k model",
Expand All @@ -158,6 +191,7 @@ func Test_classSettings_Validate(t *testing.T) {
wantPresencePenalty: 0.9,
wantErr: nil,
wantBaseURL: "https://api.openai.com",
wantApiVersion: "2023-05-15",
},
{
name: "Wrong maxTokens configured",
Expand Down Expand Up @@ -222,6 +256,17 @@ func Test_classSettings_Validate(t *testing.T) {
},
wantErr: errors.Errorf("both resourceName and deploymentId must be provided"),
},
{
name: "Wrong Azure config - wrong api version",
cfg: fakeClassConfig{
classConfig: map[string]interface{}{
"apiVersion": "wrong-api-version",
},
},
wantErr: errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: " +
"[2022-12-01 2023-03-15-preview 2023-05-15 2023-06-01-preview 2023-07-01-preview " +
"2023-08-01-preview 2023-09-01-preview 2023-12-01-preview]"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -239,6 +284,7 @@ func Test_classSettings_Validate(t *testing.T) {
assert.Equal(t, tt.wantDeploymentID, ic.DeploymentID())
assert.Equal(t, tt.wantIsAzure, ic.IsAzure())
assert.Equal(t, tt.wantBaseURL, ic.BaseURL())
assert.Equal(t, tt.wantApiVersion, ic.ApiVersion())
}
})
}
Expand Down