Skip to content

Commit

Permalink
Support custom models
Browse files Browse the repository at this point in the history
  • Loading branch information
xwjdsh authored and lyricat committed Mar 30, 2023
1 parent 1545152 commit 347951e
Show file tree
Hide file tree
Showing 18 changed files with 508 additions and 163 deletions.
2 changes: 1 addition & 1 deletion cmd/httpd/httpd.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func NewCmdHttpd() *cobra.Command {
userz := userServ.New(userServ.Config{
ExtraRate: cfg.Sys.ExtraRate,
InitUserCredits: cfg.Sys.InitUserCredits,
}, client, users, models)
}, client, users)
indexService := indexServ.NewService(ctx, gptHandler, indexes, userz)
appz := appServ.New(appServ.Config{
SecretKey: cfg.Sys.SecretKey,
Expand Down
108 changes: 108 additions & 0 deletions cmd/model/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package model

import (
"encoding/json"
"fmt"

"github.com/pandodao/botastic/config"
"github.com/pandodao/botastic/core"
"github.com/pandodao/botastic/store"
"github.com/pandodao/botastic/store/model"
"github.com/spf13/cobra"
)

func NewCmdModel() *cobra.Command {
cmd := &cobra.Command{
Use: "model",
Short: "model commands",
}

cmd.AddCommand(NewCmdModelCreate())
cmd.AddCommand(NewCmdModelList())
return cmd
}

func NewCmdModelList() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "list models",
Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()

cfg := config.C()
h := store.MustInit(store.Config{
Driver: cfg.DB.Driver,
DSN: cfg.DB.DSN,
})
models := model.New(h)

ms, err := models.GetModelsByFunction(ctx, "")
if err != nil {
cmd.PrintErr(err.Error())
return
}
for _, item := range ms {
cmd.Printf("%+v\n", item)
}
},
}

return cmd
}

func NewCmdModelCreate() *cobra.Command {
var data string
cmd := &cobra.Command{
Use: "create",
Short: "create custom model",
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

cfg := config.C()
h := store.MustInit(store.Config{
Driver: cfg.DB.Driver,
DSN: cfg.DB.DSN,
})
models := model.New(h)

m := &core.Model{}
if err := json.Unmarshal([]byte(data), m); err != nil {
return fmt.Errorf("invalid model data: %w", err)
}
if m.ProviderModel == "" {
return fmt.Errorf("provider model is empty")
}

cc, err := m.UnmarshalCustomConfig()
if err != nil {
return fmt.Errorf("unmarshal custom config error: %w", err)
}
if cc.Request.URL == "" || cc.Request.Method == "" || len(cc.Request.Data) == 0 {
return fmt.Errorf("request of custom config is empty")
}

switch m.Function {
case core.ModelFunctionChat, core.ModelFunctionEmbedding:
default:
return fmt.Errorf("invalid model function: %s", m.Function)
}

m.Provider = core.ModelProviderCustom
if err := models.CreateModel(ctx, m); err != nil {
return fmt.Errorf("create model error: %w", err)
}

mm, err := models.GetModel(ctx, m.Name())
if err != nil {
return err
}

data, _ := json.Marshal(mm)
cmd.Printf("%s\n", string(data))
return nil
},
}

cmd.Flags().StringVar(&data, "data", "", "model data in JSON format")
return cmd
}
2 changes: 2 additions & 0 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/pandodao/botastic/cmd/gen"
"github.com/pandodao/botastic/cmd/httpd"
"github.com/pandodao/botastic/cmd/migrate"
"github.com/pandodao/botastic/cmd/model"
"github.com/pandodao/botastic/cmdutil"
"github.com/pandodao/botastic/config"
"github.com/pandodao/botastic/session"
Expand Down Expand Up @@ -69,6 +70,7 @@ func NewCmdRoot(version string) *cobra.Command {
cmd.AddCommand(migrate.NewCmdMigrate())
cmd.AddCommand(gen.NewCmdGen())
cmd.AddCommand(app.NewCmdApp())
cmd.AddCommand(model.NewCmdModel())

return cmd
}
23 changes: 20 additions & 3 deletions core/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@ import (

type JSONB json.RawMessage

func (j JSONB) MarshalJSON() ([]byte, error) {
return json.RawMessage(j).MarshalJSON()
}

func (j *JSONB) UnmarshalJSON(data []byte) error {
var v json.RawMessage
if err := json.Unmarshal(data, &v); err != nil {
return err
}

*j = JSONB(v)
return nil
}

// implement sql.Scanner interface, Scan value into Jsonb
func (j *JSONB) Scan(value interface{}) error {
bytes, ok := value.([]byte)
Expand Down Expand Up @@ -165,23 +179,26 @@ func (t *Bot) DecodeMiddlewares() error {
return json.Unmarshal(val.([]byte), &t.Middlewares)
}

func (t *Bot) GetPrompt(conv *Conversation, question string) string {
func (t *Bot) GetPrompt(conv *Conversation, question string, additionData map[string]any) string {
var buf bytes.Buffer
data := map[string]interface{}{
"LangHint": conv.LangHint(),
"History": conv.HistoryToText(),
}
for k, v := range additionData {
data[k] = v
}

if t.PromptTpl == nil {
t.PromptTpl = template.Must(template.New(fmt.Sprintf("%d-prompt-tmpl", t.ID)).Parse(t.Prompt))
}
t.PromptTpl.Execute(&buf, data)

str := buf.String()

return strings.TrimSpace(str) + "\n"
}

func (t *Bot) GetChatMessages(conv *Conversation, additionData map[string]interface{}) []gogpt.ChatCompletionMessage {
func (t *Bot) GetChatMessages(conv *Conversation, additionData map[string]any) []gogpt.ChatCompletionMessage {
var buf bytes.Buffer
data := map[string]interface{}{
"LangHint": conv.LangHint(),
Expand Down
2 changes: 1 addition & 1 deletion core/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ type (
// {{end}}
// WHERE
// "id"=@id
UpdateConvTurn(ctx context.Context, id uint64, response string, totalTokens int, status int) error
UpdateConvTurn(ctx context.Context, id uint64, response string, totalTokens int64, status int) error
}

ConversationService interface {
Expand Down
80 changes: 71 additions & 9 deletions core/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,34 @@ package core

import (
"context"
"encoding/json"
"fmt"
"time"

"github.com/shopspring/decimal"
)

const (
ModelProviderOpenAI = "openai"
ModelProviderCustom = "custom"

ModelFunctionChat = "chat"
ModelFunctionEmbedding = "embedding"
)

type (
CustomConfig struct {
Request struct {
URL string `json:"url"`
Method string `json:"method"`
Headers map[string]string `json:"headers"`
Data map[string]any `json:"data"`
} `json:"request"`
Response struct {
Path string `json:"path"`
} `json:"response"`
}

Model struct {
ID uint64 `json:"id"`
Provider string `json:"provider"`
Expand All @@ -20,14 +38,11 @@ type (
PromptPriceUSD decimal.Decimal `json:"prompt_price_usd"`
CompletionPriceUSD decimal.Decimal `json:"completion_price_usd"`
PriceUSD decimal.Decimal `json:"price_usd"`
CreatedAt time.Time `json:"-"`
DeletedAt *time.Time `json:"-"`
CustomConfig JSONB `gorm:"type:jsonb;" json:"custom_config,omitempty"`
Function string `json:"function"`

Props struct {
IsOpenAIChatModel bool `yaml:"is_openai_chat_model"`
IsOpenAICompletionModel bool `yaml:"is_openai_completion_model"`
IsOpenAIEmbeddingModel bool `yaml:"is_openai_embedding_model"`
} `gorm:"-" json:"-"`
CreatedAt time.Time `json:"-"`
DeletedAt *time.Time `json:"-"`
}

ModelStore interface {
Expand All @@ -40,11 +55,24 @@ type (
// SELECT *
// FROM @@table WHERE
// "deleted_at" IS NULL
GetModels(ctx context.Context) ([]*Model, error)
// {{if f !=""}}
// AND function=@f
// {{end}}
GetModelsByFunction(ctx context.Context, f string) ([]*Model, error)

// INSERT INTO @@table
// ("provider", "provider_model", "max_token", "prompt_price_usd", "completion_price_usd", "price_usd", "custom_config", "function", "created_at")
// VALUES
// (@model.Provider, @model.ProviderModel, @model.MaxToken, @model.PromptPriceUSD, @model.CompletionPriceUSD, @model.PriceUSD, @model.CustomConfig, @model.Function, NOW())
CreateModel(ctx context.Context, model *Model) error
}
)

func (m *Model) CalculateTokenCost(promptCount, completionCount int64) decimal.Decimal {
func (m Model) Name() string {
return fmt.Sprintf("%s:%s", m.Provider, m.ProviderModel)
}

func (m Model) CalculateTokenCost(promptCount, completionCount int64) decimal.Decimal {
pc := decimal.NewFromInt(promptCount)
cc := decimal.NewFromInt(completionCount)

Expand All @@ -56,3 +84,37 @@ func (m *Model) CalculateTokenCost(promptCount, completionCount int64) decimal.D
}
return decimal.Zero
}

func (m Model) IsOpenAIChatModel() bool {
if m.Provider != ModelProviderOpenAI {
return false
}

switch m.ProviderModel {
case "gpt-4", "gpt-4-32k", "gpt-3.5-turbo":
return true
}

return false
}

func (m Model) IsOpenAICompletionModel() bool {
if m.Provider != ModelProviderOpenAI {
return false
}
switch m.ProviderModel {
case "text-davinci-003":
return true
}

return false
}

func (m Model) UnmarshalCustomConfig() (*CustomConfig, error) {
r := &CustomConfig{}
if err := json.Unmarshal(m.CustomConfig, r); err != nil {
return nil, err
}

return r, nil
}
3 changes: 1 addition & 2 deletions core/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ type (
LoginWithMixin(ctx context.Context, token, pubkey, lang string) (*User, error)
Topup(ctx context.Context, user *User, amount decimal.Decimal) error
ConsumeCredits(ctx context.Context, userID uint64, amount decimal.Decimal) error
// ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, amount uint64) error
ConsumeCreditsByModel(ctx context.Context, userID uint64, model string, promptTokenCount, completionTokenCount int64) error
ConsumeCreditsByModel(ctx context.Context, userID uint64, model Model, promptTokenCount, completionTokenCount int64) error
ReplaceStore(store UserStore) UserService
}
)
Expand Down
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
github.com/sashabaranov/go-openai v1.5.7
github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.6.1
github.com/tidwall/gjson v1.14.4
golang.org/x/sync v0.1.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.4.5
Expand Down Expand Up @@ -70,6 +71,8 @@ require (
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/subosito/gotenv v1.4.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchtv/twirp v8.1.2+incompatible // indirect
github.com/vmihailenco/tagparser v0.1.2 // indirect
google.golang.org/appengine v1.6.7 // indirect
Expand Down Expand Up @@ -123,7 +126,7 @@ require (
golang.org/x/text v0.8.0 // indirect
golang.org/x/tools v0.6.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v2 v2.4.0 // indirect
gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect
gorm.io/hints v1.1.1 // indirect
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,12 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM=
github.com/tklauser/numcpus v0.2.2 h1:oyhllyrScuYI6g+h/zUvNXNp1wy7x8qQy3t/piefldA=
github.com/twitchtv/twirp v8.1.2+incompatible h1:0O6TfzZW09ZP5r+ORA90XQEE3PTgA6C7MBbl2KxvVgE=
Expand Down
7 changes: 6 additions & 1 deletion handler/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ import (

func GetModels(models core.ModelStore) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ms, err := models.GetModels(r.Context())
ms, err := models.GetModelsByFunction(r.Context(), r.URL.Query().Get("function"))
if err != nil {
render.Error(w, http.StatusInternalServerError, err)
return
}

for _, m := range ms {
// do not expose custom config, it may contain sensitive information
m.CustomConfig = nil
}

render.JSON(w, ms)
}
}
1 change: 1 addition & 0 deletions handler/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func (s Server) HandleRest() http.Handler {

r.Route("/models", func(r chi.Router) {
r.Get("/", model.GetModels(s.models))
r.Post("/", model.GetModels(s.models))
})

r.Route("/conversations", func(r chi.Router) {
Expand Down

0 comments on commit 347951e

Please sign in to comment.