Skip to content

Commit

Permalink
🧱 unify azure config from env or yaml file
Browse files Browse the repository at this point in the history
  • Loading branch information
warjiang committed Jun 15, 2023
1 parent 8f6b8ef commit 932eefe
Showing 1 changed file with 77 additions and 27 deletions.
104 changes: 77 additions & 27 deletions azure/init.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package azure

import (
"fmt"
"github.com/spf13/viper"
"github.com/stulzq/azure-openai-proxy/constant"
"github.com/stulzq/azure-openai-proxy/util"
"log"
"net/url"
"os"
"regexp"
"strings"
)

Expand All @@ -14,43 +15,92 @@ const (
)

var (
AzureOpenAIEndpoint = ""
AzureOpenAIEndpointParse *url.URL

AzureOpenAIAPIVer = ""

AzureOpenAIModelMapper = map[string]string{
"gpt-3.5-turbo": "gpt-35-turbo",
}
fallbackModelMapper = regexp.MustCompile(`[.:]`)
C Config
ModelDeploymentConfig = map[string]DeploymentConfig{}
)

func Init() {
AzureOpenAIAPIVer = os.Getenv(constant.ENV_AZURE_OPENAI_API_VER)
AzureOpenAIEndpoint = os.Getenv(constant.ENV_AZURE_OPENAI_ENDPOINT)
func Init() error {
var (
apiVersion string
endpoint string
openaiModelMapper string
err error
)

if AzureOpenAIAPIVer == "" {
AzureOpenAIAPIVer = "2023-03-15-preview"
apiVersion = viper.GetString(constant.ENV_AZURE_OPENAI_API_VER)
endpoint = viper.GetString(constant.ENV_AZURE_OPENAI_ENDPOINT)
openaiModelMapper = viper.GetString(constant.ENV_AZURE_OPENAI_MODEL_MAPPER)
if endpoint != "" && openaiModelMapper != "" {
if apiVersion == "" {
apiVersion = "2023-03-15-preview"
}
InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper)
} else {
if err = InitFromConfigFile(); err != nil {
return err
}
}

var err error
AzureOpenAIEndpointParse, err = url.Parse(AzureOpenAIEndpoint)
if err != nil {
log.Fatal("parse AzureOpenAIEndpoint error: ", err)
// ensure apiBase likes /v1
apiBase := viper.GetString("api_base")
if !strings.HasPrefix(apiBase, "/") {
apiBase = "/" + apiBase
}
if strings.HasSuffix(apiBase, "/") {
apiBase = apiBase[:len(apiBase)-1]
}
viper.Set("api_base", apiBase)
log.Printf("apiBase is: %s", apiBase)
for _, itemConfig := range C.DeploymentConfig {
u, err := url.Parse(itemConfig.Endpoint)
if err != nil {
return fmt.Errorf("parse endpoint error: %w", err)
}
itemConfig.EndpointUrl = u
ModelDeploymentConfig[itemConfig.ModelName] = itemConfig
}
return err
}

if v := os.Getenv(constant.ENV_AZURE_OPENAI_MODEL_MAPPER); v != "" {
for _, pair := range strings.Split(v, ",") {
func InitFromEnvironmentVariables(apiVersion, endpoint, openaiModelMapper string) {
log.Println("Init from environment variables")
if openaiModelMapper != "" {
// openaiModelMapper example:
// gpt-3.5-turbo=deployment_name_for_gpt_model,text-davinci-003=deployment_name_for_davinci_model
for _, pair := range strings.Split(openaiModelMapper, ",") {
info := strings.Split(pair, "=")
if len(info) != 2 {
log.Fatalf("error parsing %s, invalid value %s", constant.ENV_AZURE_OPENAI_MODEL_MAPPER, pair)
}

AzureOpenAIModelMapper[info[0]] = info[1]
modelName, deploymentName := info[0], info[1]
ModelDeploymentConfig[modelName] = DeploymentConfig{
DeploymentName: deploymentName,
ModelName: modelName,
Endpoint: endpoint,
ApiKey: "",
ApiVersion: apiVersion,
}
}
}
}

func InitFromConfigFile() error {
log.Println("Init from config file")
workDir := util.GetWorkdir()
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(fmt.Sprintf("%s/config", workDir))
if err := viper.ReadInConfig(); err != nil {
log.Printf("read config file error: %+v\n", err)
return err
}

log.Println("AzureOpenAIAPIVer: ", AzureOpenAIAPIVer)
log.Println("AzureOpenAIEndpoint: ", AzureOpenAIEndpoint)
log.Println("AzureOpenAIModelMapper: ", AzureOpenAIModelMapper)
if err := viper.Unmarshal(&C); err != nil {
log.Printf("unmarshal config file error: %+v\n", err)
return err
}
for _, configItem := range C.DeploymentConfig {
ModelDeploymentConfig[configItem.ModelName] = configItem
}
return nil
}

0 comments on commit 932eefe

Please sign in to comment.