Skip to content

Commit

Permalink
Support for dynamically setting up middlewares in conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
xwjdsh authored and lyricat committed Apr 3, 2023
1 parent c1d22c9 commit ab5d32c
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 203 deletions.
6 changes: 1 addition & 5 deletions cmd/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ func NewCmdModelCreate() *cobra.Command {
return fmt.Errorf("provider model is empty")
}

cc, err := m.UnmarshalCustomConfig()
if err != nil {
return fmt.Errorf("unmarshal custom config error: %w", err)
}
if cc.Request.URL == "" || cc.Request.Method == "" || len(cc.Request.Data) == 0 {
if m.CustomConfig.Request.URL == "" || m.CustomConfig.Request.Method == "" || len(m.CustomConfig.Request.Data) == 0 {
return fmt.Errorf("request of custom config is empty")
}

Expand Down
99 changes: 45 additions & 54 deletions core/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,69 @@ import (
gogpt "github.com/sashabaranov/go-openai"
)

type JSONB json.RawMessage
const (
MiddlewareBotasticSearch = "botastic-search"
MiddlewareDuckduckgoSearch = "duckduckgo-search"
MiddlewareIntentRecognition = "intent-recognition"
)

func (j JSONB) MarshalJSON() ([]byte, error) {
return json.RawMessage(j).MarshalJSON()
}
const MiddlewareProcessCodeUnknown = -1

func (j *JSONB) UnmarshalJSON(data []byte) error {
var v json.RawMessage
if err := json.Unmarshal(data, &v); err != nil {
return err
const (
MiddlewareProcessCodeOK = iota
)

type (
Middleware struct {
Name string `json:"name"`
Options map[string]any `json:"options,omitempty"`
}

*j = JSONB(v)
return nil
}
MiddlewareConfig struct {
Items []*Middleware `json:"items"`
}

// implement sql.Scanner interface, Scan value into Jsonb
func (j *JSONB) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("type assertion to []byte failed:", value))
MiddlewareProcessResult struct {
Name string `json:"name"`
Code uint64 `json:"code"`
Result string `json:"result"`
}

result := json.RawMessage{}
err := json.Unmarshal(bytes, &result)
if err != nil {
return err
MiddlewareService interface {
Process(ctx context.Context, m *Middleware, input string) (*MiddlewareProcessResult, error)
}
*j = JSONB(result)
return err
)

func (a MiddlewareConfig) Value() (driver.Value, error) {
return json.Marshal(a)
}

// implement driver.Valuer interface, Value return json value
func (j JSONB) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
func (a *MiddlewareConfig) Scan(value interface{}) error {
b, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.RawMessage(j).MarshalJSON()
return json.Unmarshal(b, a)
}

type (
Bot struct {
ID uint64 `json:"id"`
Name string `json:"name"`
UserID uint64 `json:"user_id"`
Prompt string `json:"prompt"`
Model string `json:"model"`
MaxTurnCount int `json:"max_turn_count"`
ContextTurnCount int `json:"context_turn_count"`
Temperature float32 `json:"temperature"`
MiddlewareJson JSONB `gorm:"type:jsonb" json:"-"`
Public bool `json:"public"`
ID uint64 `json:"id"`
Name string `json:"name"`
UserID uint64 `json:"user_id"`
Prompt string `json:"prompt"`
Model string `json:"model"`
MaxTurnCount int `json:"max_turn_count"`
ContextTurnCount int `json:"context_turn_count"`
Temperature float32 `json:"temperature"`
MiddlewareJson MiddlewareConfig `gorm:"type:jsonb" json:"middlewares"`
Public bool `json:"public"`

CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at"`

PromptTpl *template.Template `gorm:"-" json:"-"`
Middlewares MiddlewareConfig `gorm:"-" json:"middlewares"`
PromptTpl *template.Template `gorm:"-" json:"-"`
}

BotStore interface {
Expand Down Expand Up @@ -120,7 +124,7 @@ type (
temperature float32,
maxTurnCount,
contextTurnCount int,
middlewareJson JSONB, public bool,
middlewareJson MiddlewareConfig, public bool,
) (uint64, error)

// UPDATE @@table
Expand All @@ -142,7 +146,7 @@ type (
temperature float32,
maxTurnCount,
contextTurnCount int,
middlewareJson JSONB,
middlewareJson MiddlewareConfig,
public bool,
) error

Expand All @@ -166,19 +170,6 @@ type (
}
)

func (t *Bot) DecodeMiddlewares() error {
if t.MiddlewareJson == nil {
return nil
}

val, err := t.MiddlewareJson.Value()
if err != nil {
return err
}

return json.Unmarshal(val.([]byte), &t.Middlewares)
}

func (t *Bot) GetPrompt(conv *Conversation, question string, additionData map[string]any) string {
var buf bytes.Buffer
data := map[string]interface{}{
Expand Down
3 changes: 2 additions & 1 deletion core/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ const (
)

type BotOverride struct {
Temperature *float32 `json:"temperature,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
Middlewares *MiddlewareConfig `json:"middlewares,omitempty"`
}

func (b *BotOverride) Scan(value interface{}) error {
Expand Down
67 changes: 0 additions & 67 deletions core/middleware.go

This file was deleted.

48 changes: 27 additions & 21 deletions core/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package core

import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"time"

Expand All @@ -17,19 +19,32 @@ const (
ModelFunctionEmbedding = "embedding"
)

type (
CustomConfig struct {
Request struct {
URL string `json:"url"`
Method string `json:"method"`
Headers map[string]string `json:"headers"`
Data map[string]any `json:"data"`
} `json:"request"`
Response struct {
Path string `json:"path"`
} `json:"response"`
type CustomConfig struct {
Request *struct {
URL string `json:"url"`
Method string `json:"method"`
Headers map[string]string `json:"headers"`
Data map[string]any `json:"data"`
} `json:"request,omitempty"`
Response *struct {
Path string `json:"path"`
} `json:"response,omitempty"`
}

func (c *CustomConfig) Scan(value interface{}) error {
data, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("type assertion to []byte failed:", value))
}

return json.Unmarshal(data, c)
}

func (b CustomConfig) Value() (driver.Value, error) {
return json.Marshal(b)
}

type (
Model struct {
ID uint64 `json:"id"`
Provider string `json:"provider"`
Expand All @@ -38,7 +53,7 @@ type (
PromptPriceUSD decimal.Decimal `json:"prompt_price_usd"`
CompletionPriceUSD decimal.Decimal `json:"completion_price_usd"`
PriceUSD decimal.Decimal `json:"price_usd"`
CustomConfig JSONB `gorm:"type:jsonb;" json:"custom_config,omitempty"`
CustomConfig CustomConfig `gorm:"type:jsonb;" json:"custom_config,omitempty"`
Function string `json:"function"`

CreatedAt time.Time `json:"-"`
Expand Down Expand Up @@ -109,12 +124,3 @@ func (m Model) IsOpenAICompletionModel() bool {

return false
}

func (m Model) UnmarshalCustomConfig() (*CustomConfig, error) {
r := &CustomConfig{}
if err := json.Unmarshal(m.CustomConfig, r); err != nil {
return nil, err
}

return r, nil
}
2 changes: 1 addition & 1 deletion handler/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (body *CreateOrUpdateBotPayload) Formalize(defaultValue *core.Bot) error {
body.ContextTurnCount = defaultValue.ContextTurnCount
}
if body.Middlewares.Items == nil {
body.Middlewares.Items = defaultValue.Middlewares.Items
body.Middlewares.Items = defaultValue.MiddlewareJson.Items
}
} else {
if body.Temperature <= 0 {
Expand Down
2 changes: 1 addition & 1 deletion handler/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func GetModels(models core.ModelStore) http.HandlerFunc {

for _, m := range ms {
// do not expose custom config, it may contain sensitive information
m.CustomConfig = nil
m.CustomConfig = core.CustomConfig{}
}

render.JSON(w, ms)
Expand Down

0 comments on commit ab5d32c

Please sign in to comment.