add model store; add support for openai's gpt-4
lyricat committed Mar 27, 2023
1 parent 311dd6b commit 71617ea
Showing 12 changed files with 226 additions and 50 deletions.
6 changes: 4 additions & 2 deletions cmd/httpd/httpd.go
@@ -30,6 +30,7 @@ import (
"github.com/pandodao/botastic/store/bot"
"github.com/pandodao/botastic/store/conv"
"github.com/pandodao/botastic/store/index"
"github.com/pandodao/botastic/store/model"
"github.com/pandodao/botastic/store/order"
"github.com/pandodao/botastic/store/user"
"github.com/pandodao/botastic/worker"
@@ -85,11 +86,12 @@ func NewCmdHttpd() *cobra.Command {
return err
}
indexes := index.New(ctx, milvusClient)
models := model.New()

userz := userServ.New(userServ.Config{
ExtraRate: cfg.Sys.ExtraRate,
InitUserCredits: cfg.Sys.InitUserCredits,
}, client, users)
}, client, users, models)
indexService := indexServ.NewService(ctx, gptHandler, indexes, userz)

middlewarez := middlewareServ.New(middlewareServ.Config{}, indexService)
@@ -109,7 +111,7 @@ func NewCmdHttpd() *cobra.Command {
// httpd's workers
workers := []worker.Worker{
// rotater
rotater.New(rotater.Config{}, gptHandler, convs, apps, convz, botz, middlewarez, userz, hub),
rotater.New(rotater.Config{}, gptHandler, convs, apps, models, convz, botz, middlewarez, userz, hub),

ordersyncer.New(ordersyncer.Config{
Interval: cfg.OrderSyncer.Interval,
2 changes: 1 addition & 1 deletion core/bot.go
@@ -11,7 +11,7 @@ import (
"text/template"
"time"

gogpt "github.com/sashabaranov/go-gpt3"
gogpt "github.com/sashabaranov/go-openai"
)

type JSONB json.RawMessage
52 changes: 52 additions & 0 deletions core/model.go
@@ -0,0 +1,52 @@
package core

import (
"context"

"github.com/shopspring/decimal"
)

const (
ModelProviderOpenAI = "openai"
)

const (
ModelOpenAIGPT4 = "openai:gpt-4"
ModelOpenAIGPT3Dot5Turbo = "openai:gpt-3.5-turbo"
ModelOpenAIGPT3TextDavinci003 = "openai:text-davinci-003"
ModelOpenAIAdaEmbeddingV2 = "openai:text-embedding-ada-002"
)

type (
Model struct {
Provider string `yaml:"provider"`
ProviderModel string `yaml:"provider_model"`
MaxToken int `yaml:"max_token"`
PromptPriceUSD decimal.Decimal `yaml:"prompt_price_usd"`
CompletionPriceUSD decimal.Decimal `yaml:"completion_price_usd"`
PriceUSD decimal.Decimal `yaml:"price_usd"`

Props struct {
IsOpenAIChatModel bool `yaml:"is_openai_chat_model"`
IsOpenAICompletionModel bool `yaml:"is_openai_completion_model"`
IsOpenAIEmbeddingModel bool `yaml:"is_openai_embedding_model"`
}
}

ModelStore interface {
GetModel(ctx context.Context, name string) (*Model, error)
}
)

func (m *Model) CalculateTokenCost(promptCount, completionCount int64) decimal.Decimal {
pc := decimal.NewFromInt(promptCount)
cc := decimal.NewFromInt(completionCount)

if m.PriceUSD.IsPositive() {
return m.PriceUSD.Mul(pc.Add(cc))
}
if m.PromptPriceUSD.IsPositive() && m.CompletionPriceUSD.IsPositive() {
return m.PromptPriceUSD.Mul(pc).Add(m.CompletionPriceUSD.Mul(cc))
}
return decimal.Zero
}
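
A minimal usage sketch of the new pricing helper (illustrative program, not part of the commit; prices copied from the models.yaml added below): a model with separate prompt/completion prices bills the two token counts independently, while a model with a single flat price_usd bills the combined count.

// Sketch only: exercising core.Model.CalculateTokenCost with prices from models.yaml.
package main

import (
	"fmt"

	"github.com/pandodao/botastic/core"
	"github.com/shopspring/decimal"
)

func main() {
	// gpt-4 uses split pricing: $0.03 / 1K prompt tokens, $0.06 / 1K completion tokens.
	gpt4 := &core.Model{
		Provider:           "openai",
		ProviderModel:      "gpt-4",
		PromptPriceUSD:     decimal.NewFromFloat(0.00003),
		CompletionPriceUSD: decimal.NewFromFloat(0.00006),
	}
	// 1000 prompt + 500 completion tokens => 0.03 + 0.03 = $0.06
	fmt.Println(gpt4.CalculateTokenCost(1000, 500))

	// gpt-3.5-turbo uses a single flat price: $0.002 / 1K tokens.
	turbo := &core.Model{
		Provider:      "openai",
		ProviderModel: "gpt-3.5-turbo",
		PriceUSD:      decimal.NewFromFloat(0.000002),
	}
	// The flat price applies to prompt+completion combined: 1500 * 0.000002 = $0.003
	fmt.Println(turbo.CalculateTokenCost(1000, 500))
}
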
3 changes: 2 additions & 1 deletion core/user.go
@@ -89,7 +89,8 @@ type (
LoginWithMixin(ctx context.Context, token, pubkey, lang string) (*User, error)
Topup(ctx context.Context, user *User, amount decimal.Decimal) error
ConsumeCredits(ctx context.Context, userID uint64, amount decimal.Decimal) error
ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, amount uint64) error
// ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, amount uint64) error
ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, promptTokenCount, completionTokenCount int64) error
ReplaceStore(store UserStore) UserService
}
)
4 changes: 2 additions & 2 deletions go.mod
@@ -20,7 +20,7 @@ require (
github.com/pandodao/tokenizer-go v0.0.1
github.com/pressly/goose/v3 v3.9.0
github.com/rs/cors v1.8.3
github.com/sashabaranov/go-gpt3 v1.3.3
github.com/sashabaranov/go-openai v1.5.7
github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.6.1
golang.org/x/sync v0.1.0
@@ -91,7 +91,7 @@ require (
github.com/go-playground/validator/v10 v10.11.1 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.9.11 // indirect
github.com/gofrs/uuid v4.4.0+incompatible // indirect
github.com/gofrs/uuid v4.4.0+incompatible
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/inconshreveable/mousetrap v1.0.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
4 changes: 2 additions & 2 deletions go.sum
@@ -519,8 +519,8 @@ github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThC
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/sashabaranov/go-gpt3 v1.3.3 h1:S8Zd4YybnBaNMK+w+XGGWgsjQY1R+6QE2n9SLzVna9k=
github.com/sashabaranov/go-gpt3 v1.3.3/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
github.com/sashabaranov/go-openai v1.5.7 h1:8DGgRG+P7yWixte5j720y6yiXgY3Hlgcd0gcpHdltfo=
github.com/sashabaranov/go-openai v1.5.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/schollz/closestmatch v2.1.0+incompatible/go.mod h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
2 changes: 1 addition & 1 deletion internal/gpt/gpt.go
@@ -7,7 +7,7 @@ import (
"time"

"github.com/fox-one/pkg/logger"
gogpt "github.com/sashabaranov/go-gpt3"
gogpt "github.com/sashabaranov/go-openai"
)

var ErrTooManyRequests = errors.New("too many requests")
4 changes: 2 additions & 2 deletions service/index/index.go
@@ -11,7 +11,7 @@ import (
"github.com/pandodao/botastic/internal/milvus"
"github.com/pandodao/botastic/session"
"github.com/pandodao/tokenizer-go"
gogpt "github.com/sashabaranov/go-gpt3"
gogpt "github.com/sashabaranov/go-openai"
)

func NewService(ctx context.Context, gptHandler *gpt.Handler, indexes core.IndexStore, userz core.UserService) core.IndexService {
@@ -39,7 +39,7 @@ func (s *serviceImpl) createEmbeddingsWithLimit(ctx context.Context, req gogpt.E

resp, err := s.gptHandler.CreateEmbeddings(ctx, req)
if err == nil {
if err := s.userz.ConsumeCreditsByModel(ctx, userID, gogpt.AdaEmbeddingV2.String(), uint64(resp.Usage.TotalTokens)); err != nil {
if err := s.userz.ConsumeCreditsByModel(ctx, userID, gogpt.AdaEmbeddingV2.String(), int64(resp.Usage.PromptTokens), int64(resp.Usage.CompletionTokens)); err != nil {
log.Printf("ConsumeCredits error: %v\n", err)
}
}
47 changes: 29 additions & 18 deletions service/user/user.go
@@ -7,24 +7,26 @@ import (
"strings"

"github.com/pandodao/botastic/core"
gogpt "github.com/sashabaranov/go-gpt3"
"github.com/shopspring/decimal"
"gorm.io/gorm"

"github.com/ethereum/go-ethereum/common"
"github.com/fox-one/mixin-sdk-go"
"github.com/fox-one/passport-go/mvm"
"github.com/fox-one/pkg/logger"
)

func New(
cfg Config,
client *mixin.Client,
users core.UserStore,
models core.ModelStore,
) *UserService {
return &UserService{
cfg: cfg,
client: client,
users: users,
models: models,
}
}

@@ -37,10 +39,11 @@ type UserService struct {
cfg Config
client *mixin.Client
users core.UserStore
models core.ModelStore
}

func (s *UserService) ReplaceStore(users core.UserStore) core.UserService {
return New(s.cfg, s.client, users)
return New(s.cfg, s.client, users, s.models)
}

func (s *UserService) LoginWithMixin(ctx context.Context, token, pubkey, lang string) (*core.User, error) {
@@ -148,27 +151,35 @@ func (s *UserService) Topup(ctx context.Context, user *core.User, amount decimal
return nil
}

func (s *UserService) ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, tokenCount uint64) error {
price := decimal.Zero
switch model {
case gogpt.GPT3Dot5Turbo:
// $0.002 per 1000 tokens
price = decimal.NewFromFloat(0.000002)
case gogpt.GPT3TextDavinci003:
// $0.02 per 1000 tokens
price = decimal.NewFromFloat(0.00002)
case gogpt.AdaEmbeddingV2.String():
// $0.0004 per 1000 tokens
price = decimal.NewFromFloat(0.0000004)
default:
return core.ErrInvalidModel
func (s *UserService) ConsumeCreditsByModel(ctx context.Context, userID uint64, modelName string, promptTokenCount, completionTokenCount int64) error {
log := logger.FromContext(ctx).WithField("service", "user.ConsumeCreditsByModel")
model, err := s.models.GetModel(ctx, modelName)
if err != nil {
return err
}

credits := price.Mul(decimal.NewFromInt(int64(tokenCount)))
// price := decimal.Zero
// switch model {
// case gogpt.GPT3Dot5Turbo:
// // $0.002 per 1000 tokens
// price = decimal.NewFromFloat(0.000002)
// case gogpt.GPT3TextDavinci003:
// // $0.02 per 1000 tokens
// price = decimal.NewFromFloat(0.00002)
// case gogpt.AdaEmbeddingV2.String():
// // $0.0004 per 1000 tokens
// price = decimal.NewFromFloat(0.0000004)
// default:
// return core.ErrInvalidModel
// }

cost := model.CalculateTokenCost(promptTokenCount, completionTokenCount)
credits := cost
if s.cfg.ExtraRate > 0 {
credits = credits.Mul(decimal.NewFromFloat(1 + s.cfg.ExtraRate))
}
fmt.Printf("model: %v, price: $%s, token: %d, credits: $%s\n", model, price.StringFixed(8), tokenCount, credits.StringFixed(8))
log.Printf("model: %s:%s, cost: $%s, token: %d->%d, credits: $%s\n", model.Provider, model.ProviderModel,
cost.StringFixed(8), promptTokenCount, completionTokenCount, credits.StringFixed(8))
return s.ConsumeCredits(ctx, userID, credits)
}

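A hedged sketch of how a caller is expected to use the new signature (the surrounding worker and the resp, userz, ctx and userID variables are assumed; the usage counts come from go-openai's response):

// Sketch only: billing a chat completion after the fact.
// Assumed context: resp is a gogpt.ChatCompletionResponse from go-openai,
// userz is a core.UserService, ctx/userID come from the caller.
if err := userz.ConsumeCreditsByModel(ctx, userID, core.ModelOpenAIGPT4,
	int64(resp.Usage.PromptTokens), int64(resp.Usage.CompletionTokens)); err != nil {
	log.Printf("ConsumeCreditsByModel error: %v", err)
}
// With gpt-4 priced at $0.03/$0.06 per 1K tokens, 1,200 prompt + 350 completion
// tokens cost 1200*0.00003 + 350*0.00006 = $0.057; with ExtraRate = 0.1 the user
// is charged 0.057 * 1.1 = $0.0627 in credits.
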
68 changes: 68 additions & 0 deletions store/model/model.go
@@ -0,0 +1,68 @@
package model

import (
"context"
_ "embed"
"fmt"

"github.com/pandodao/botastic/core"
"gopkg.in/yaml.v2"
)

type (
store struct {
modelMap map[string]*core.Model
}
)

//go:embed models.yaml
var modelsConfig string

func New() *store {
modelMap, err := LoadModels()
if err != nil {
panic(err)
}
return &store{
modelMap: modelMap,
}
}

func LoadModels() (map[string]*core.Model, error) {
models := []*core.Model{}
modelMap := make(map[string]*core.Model)

if err := yaml.Unmarshal([]byte(modelsConfig), &models); err != nil {
return nil, err
}

for _, m := range models {
key := fmt.Sprintf("%s:%s", m.Provider, m.ProviderModel)
modelMap[key] = m

if m.Provider == core.ModelProviderOpenAI {
switch m.ProviderModel {
case "gpt-4", "gpt-4-32k", "gpt-4-0314", "gpt-4-32k-0314", "gpt-3.5-turbo", "gpt-3.5-turbo-0301":
m.Props.IsOpenAIChatModel = true
case "text-davinci-003":
m.Props.IsOpenAICompletionModel = true
case "text-embedding-ada-002":
m.Props.IsOpenAIEmbeddingModel = true
}
}

// backward compatibility
if m.Provider == core.ModelProviderOpenAI {
modelMap[m.ProviderModel] = m
}
}

return modelMap, nil
}

func (s *store) GetModel(ctx context.Context, name string) (*core.Model, error) {
if model, ok := s.modelMap[name]; ok {
return model, nil
}
return nil, core.ErrInvalidModel
}
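
For context, a minimal lookup sketch (call site assumed, not part of the commit): the store resolves both the provider-qualified key and, for backward compatibility, the bare OpenAI model name, which is why existing callers such as the index service can keep passing gogpt.AdaEmbeddingV2.String():

// Sketch only: resolving models from the embedded store.
// Assumed context: ctx is an existing context.Context.
models := model.New()

m, err := models.GetModel(ctx, "openai:gpt-4") // provider-qualified key
if err != nil {
	return err // core.ErrInvalidModel if the name is not listed in models.yaml
}

legacy, _ := models.GetModel(ctx, "text-embedding-ada-002") // legacy bare name
fmt.Println(m.MaxToken, legacy.MaxToken)                    // 8192 2049
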
36 changes: 36 additions & 0 deletions store/model/models.yaml
@@ -0,0 +1,36 @@
- provider: openai
provider_model: gpt-4
prompt_price_usd: 0.00003 # $0.03 / 1K tokens
completion_price_usd: 0.00006 # $0.06 / 1K tokens
max_token: 8192
- provider: openai
provider_model: gpt-4-0314
prompt_price_usd: 0.00003 # $0.03 / 1K tokens
completion_price_usd: 0.00006 # $0.06 / 1K tokens
max_token: 8192
- provider: openai
provider_model: gpt-4-32k
prompt_price_usd: 0.00006 # $0.06 / 1K tokens
completion_price_usd: 0.00012 # $0.12 / 1K tokens
max_token: 32768
- provider: openai
provider_model: gpt-4-32k-0314
prompt_price_usd: 0.00006 # $0.06 / 1K tokens
completion_price_usd: 0.00012 # $0.12 / 1K tokens
max_token: 32768
- provider: openai
provider_model: gpt-3.5-turbo
price_usd: 0.000002 # $0.002 per 1000 tokens
max_token: 4096
- provider: openai
provider_model: gpt-3.5-turbo-0301
price_usd: 0.000002 # $0.002 per 1000 tokens
max_token: 4096
- provider: openai
provider_model: text-davinci-003
price_usd: 0.00002 # $0.02 per 1000 tokens
max_token: 4097
- provider: openai
provider_model: text-embedding-ada-002
  price_usd: 0.0000004 # $0.0004 per 1000 tokens
max_token: 2049
