Skip to content

Commit

Permalink
Save middleware results to conv turn
Browse files Browse the repository at this point in the history
* change the general option of middlewares from `break` to `required`
  • Loading branch information
xwjdsh authored and lyricat committed Apr 11, 2023
1 parent 48182cb commit 1b16e4a
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 105 deletions.
14 changes: 7 additions & 7 deletions core/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ type (
}

MiddlewareProcessResult struct {
Opts any `json:"opts,omitempty"`
Name string `json:"name"`
Code int `json:"code"`
Result string `json:"result,omitempty"`
Err error `json:"err,omitempty"`
Break bool `json:"break,omitempty"`
Opts map[string]any `json:"opts,omitempty"`
Name string `json:"name"`
Code int `json:"code"`
Result string `json:"result,omitempty"`
Err error `json:"err,omitempty"`
Required bool `json:"required,omitempty"`
}

MiddlewareService interface {
ProcessByConfig(ctx context.Context, m MiddlewareConfig, input string) []*MiddlewareProcessResult
ProcessByConfig(ctx context.Context, m MiddlewareConfig, input string) MiddlewareResults
Process(ctx context.Context, m *Middleware, input string) *MiddlewareProcessResult
}
)
Expand Down
55 changes: 37 additions & 18 deletions core/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ func (b BotOverride) Value() (driver.Value, error) {
return json.Marshal(b)
}

type MiddlewareResults []*MiddlewareProcessResult

func (mr *MiddlewareResults) Scan(value interface{}) error {
data, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("type assertion to []byte failed:", value))
}

return json.Unmarshal(data, mr)
}

func (mr MiddlewareResults) Value() (driver.Value, error) {
return json.Marshal(mr)
}

type (
Conversation struct {
ID string `json:"id"`
Expand All @@ -47,21 +62,22 @@ type (
}

ConvTurn struct {
ID uint64 `json:"id"`
ConversationID string `json:"conversation_id"`
BotID uint64 `json:"bot_id"`
AppID uint64 `json:"app_id"`
UserID uint64 `json:"user_id"`
UserIdentity string `json:"user_identity"`
Request string `json:"request"`
Response string `json:"response"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
Status int `json:"status"`
BotOverride BotOverride `gorm:"type:jsonb" json:"bot_override"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
ID uint64 `json:"id"`
ConversationID string `json:"conversation_id"`
BotID uint64 `json:"bot_id"`
AppID uint64 `json:"app_id"`
UserID uint64 `json:"user_id"`
UserIdentity string `json:"user_identity"`
Request string `json:"request"`
Response string `json:"response"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
Status int `json:"status"`
BotOverride BotOverride `gorm:"type:jsonb" json:"bot_override"`
MiddlewareResults MiddlewareResults `gorm:"type:jsonb" json:"middleware_results,omitempty"`
CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
}

ConversationStore interface {
Expand Down Expand Up @@ -98,7 +114,7 @@ type (
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override", "middleware_results",
// "created_at", "updated_at"
// FROM "conv_turns" WHERE
// "id" = @id
Expand All @@ -109,7 +125,7 @@ type (
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override", "middleware_results",
// "created_at", "updated_at"
// FROM "conv_turns"
// {{where}}
Expand All @@ -127,11 +143,14 @@ type (
// "completion_tokens"=@completionTokens,
// "total_tokens"=@totalTokens,
// "status"=@status,
// {{if mr != nil}}
// "middleware_results"=@mr,
// {{end}}
// "updated_at"=NOW()
// {{end}}
// WHERE
// "id"=@id
UpdateConvTurn(ctx context.Context, id uint64, response string, promptTokens, completionTokens, totalTokens int64, status int) error
UpdateConvTurn(ctx context.Context, id uint64, response string, promptTokens, completionTokens, totalTokens int64, status int, mr MiddlewareResults) error
}

ConversationService interface {
Expand Down
10 changes: 4 additions & 6 deletions service/middleware/botastic_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@ type botasticSearch struct {
}

type botasticSearchOptions struct {
*generalOptions
Limit int `json:"limit"`
AppID string `json:"app_id"`
Limit int
AppID string
}

func (m *botasticSearch) Name() string {
return core.MiddlewareBotasticSearch
}

func (m *botasticSearch) ValidateOptions(gopts *generalOptions, opts map[string]any) (any, error) {
func (m *botasticSearch) ValidateOptions(opts map[string]any) (any, error) {
options := &botasticSearchOptions{
generalOptions: gopts,
Limit: 3,
Limit: 3,
}

if val, ok := opts["limit"]; ok {
Expand Down
8 changes: 3 additions & 5 deletions service/middleware/ddg_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ func (m *duckDuckGoSearch) Name() string {
}

type duckDuckGoSearchOptions struct {
*generalOptions
Limit int `json:"limit"`
Limit int
}

func (m *duckDuckGoSearch) ValidateOptions(gopts *generalOptions, opts map[string]any) (any, error) {
func (m *duckDuckGoSearch) ValidateOptions(opts map[string]any) (any, error) {
options := &duckDuckGoSearchOptions{
generalOptions: gopts,
Limit: 3,
Limit: 3,
}

if val, ok := opts["limit"]; ok {
Expand Down
9 changes: 3 additions & 6 deletions service/middleware/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@ func (m *fetch) Name() string {
}

type fetchOptions struct {
*generalOptions
URL string `json:"url"`
URL string
}

func (m *fetch) ValidateOptions(gopts *generalOptions, opts map[string]any) (any, error) {
options := &fetchOptions{
generalOptions: gopts,
}
func (m *fetch) ValidateOptions(opts map[string]any) (any, error) {
options := &fetchOptions{}

if val, ok := opts["url"]; ok {
v, ok := val.(string)
Expand Down
9 changes: 3 additions & 6 deletions service/middleware/intent_recognition.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@ import (
type intentRecognition struct{}

type intentRecognitionOptions struct {
*generalOptions
Intents []string `json:"intents"`
Intents []string
}

func (m *intentRecognition) Name() string {
return core.MiddlewareIntentRecognition
}

func (m *intentRecognition) ValidateOptions(gopts *generalOptions, opts map[string]any) (any, error) {
options := &intentRecognitionOptions{
generalOptions: gopts,
}
func (m *intentRecognition) ValidateOptions(opts map[string]any) (any, error) {
options := &intentRecognitionOptions{}

val, ok := opts["intents"]
if ok {
Expand Down
52 changes: 26 additions & 26 deletions service/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (

type Middleware interface {
Name() string
ValidateOptions(gopts *generalOptions, opts map[string]any) (any, error)
ValidateOptions(opts map[string]any) (any, error)
Process(ctx context.Context, opts any, input string) (string, error)
}

type generalOptions struct {
Break bool `json:"break"`
Timeout time.Duration `json:"timeout"`
Required bool `json:"required"`
Timeout time.Duration `json:"timeout"`
}

func New(
Expand Down Expand Up @@ -60,12 +60,12 @@ type (
}
)

func (s *service) ProcessByConfig(ctx context.Context, mc core.MiddlewareConfig, input string) []*core.MiddlewareProcessResult {
func (s *service) ProcessByConfig(ctx context.Context, mc core.MiddlewareConfig, input string) core.MiddlewareResults {
var results []*core.MiddlewareProcessResult
for _, m := range mc.Items {
result := s.Process(ctx, m, input)
results = append(results, result)
if result.Break {
if result.Required {
break
}
}
Expand All @@ -76,30 +76,30 @@ func (s *service) Process(ctx context.Context, m *core.Middleware, input string)
gopts, err := parseGeneralOptions(ctx, m.Options)
if err != nil {
return &core.MiddlewareProcessResult{
Name: m.Name,
Code: core.MiddlewareProcessCodeInvalidOptions,
Err: err,
Break: true,
Name: m.Name,
Code: core.MiddlewareProcessCodeInvalidOptions,
Err: err,
Required: true,
}
}

middleware := s.middlewareMap[m.Name]
if middleware == nil {
return &core.MiddlewareProcessResult{
Name: m.Name,
Code: core.MiddlewareProcessCodeUnknown,
Err: errors.New("middleware not found"),
Break: gopts.Break,
Name: m.Name,
Code: core.MiddlewareProcessCodeUnknown,
Err: errors.New("middleware not found"),
Required: gopts.Required,
}
}

opts, err := middleware.ValidateOptions(gopts, m.Options)
opts, err := middleware.ValidateOptions(m.Options)
if err != nil {
return &core.MiddlewareProcessResult{
Name: m.Name,
Code: core.MiddlewareProcessCodeInvalidOptions,
Err: err,
Break: gopts.Break,
Name: m.Name,
Code: core.MiddlewareProcessCodeInvalidOptions,
Err: err,
Required: gopts.Required,
}
}

Expand All @@ -117,15 +117,15 @@ func (s *service) Process(ctx context.Context, m *core.Middleware, input string)
}

return &core.MiddlewareProcessResult{
Name: m.Name,
Code: code,
Err: err,
Break: gopts.Break,
Name: m.Name,
Code: code,
Err: err,
Required: gopts.Required,
}
}

return &core.MiddlewareProcessResult{
Opts: opts,
Opts: m.Options,
Name: m.Name,
Code: core.MiddlewareProcessCodeOK,
Result: result,
Expand All @@ -135,12 +135,12 @@ func (s *service) Process(ctx context.Context, m *core.Middleware, input string)
func parseGeneralOptions(ctx context.Context, opts map[string]any) (*generalOptions, error) {
generalOptions := &generalOptions{}

if val, ok := opts["break"]; ok {
if val, ok := opts["required"]; ok {
b, ok := val.(bool)
if !ok {
return nil, fmt.Errorf("break should be bool: %v", val)
return nil, fmt.Errorf("required should be bool: %v", val)
}
generalOptions.Break = b
generalOptions.Required = b
}

if val, ok := opts["timeout"]; ok {
Expand Down

0 comments on commit 1b16e4a

Please sign in to comment.