Skip to content

Commit

Permalink
Add boundary prompt to bots
Browse files Browse the repository at this point in the history
  • Loading branch information
xwjdsh authored and lyricat committed Apr 18, 2023
1 parent 21b8cde commit 30f9834
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 157 deletions.
138 changes: 73 additions & 65 deletions core/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,90 +76,68 @@ type (
Name string `json:"name"`
UserID uint64 `json:"user_id"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
Model string `json:"model"`
MaxTurnCount int `json:"max_turn_count"`
ContextTurnCount int `json:"context_turn_count"`
Temperature float32 `json:"temperature"`
MiddlewareJson MiddlewareConfig `gorm:"type:jsonb" json:"middlewares"`
Public bool `json:"public"`

CreatedAt *time.Time `json:"created_at"`
UpdatedAt *time.Time `json:"updated_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"deleted_at"`

PromptTpl *template.Template `gorm:"-" json:"-"`
PromptTpl *template.Template `gorm:"-" json:"-"`
BoundaryPromptTpl *template.Template `gorm:"-" json:"-"`
}

BotStore interface {
// SELECT "id",
// "user_id", "name", "model", "prompt", "temperature",
// "max_turn_count", "context_turn_count",
// "middleware_json", "public",
// "created_at", "updated_at"
// SELECT *
// FROM @@table WHERE
// "id"=@id AND "deleted_at" IS NULL
// LIMIT 1
GetBot(ctx context.Context, id uint64) (*Bot, error)

// SELECT "id",
// "user_id", "name", "model", "prompt", "temperature",
// "max_turn_count", "context_turn_count",
// "middleware_json", "public",
// "created_at", "updated_at"
// SELECT *
// FROM @@table WHERE
// "user_id"=@userID AND "deleted_at" IS NULL
GetBotsByUserID(ctx context.Context, userID uint64) ([]*Bot, error)

// SELECT "id",
// "user_id", "name", "model", "prompt", "temperature",
// "max_turn_count", "context_turn_count",
// "middleware_json", "public",
// "created_at", "updated_at"
// SELECT *
// FROM @@table WHERE
// "public"='t' AND "deleted_at" IS NULL
GetPublicBots(ctx context.Context) ([]*Bot, error)

// INSERT INTO @@table
// ("user_id", "name", "model", "prompt", "temperature",
// ("user_id", "name", "model", "prompt", "boundary_prompt", "temperature",
// "max_turn_count", "context_turn_count",
// "middleware_json", "public",
// "created_at", "updated_at")
// VALUES
// (@userID, @name, @model, @prompt, @temperature,
// @maxTurnCount, @contextTurnCount,
// @middlewareJson, @public,
// (@bot.UserID, @bot.Name, @bot.Model, @bot.Prompt, @bot.BoundaryPrompt, @bot.Temperature,
// @bot.MaxTurnCount, @bot.ContextTurnCount,
// @bot.MiddlewareJson, @bot.Public,
// NOW(), NOW())
// RETURNING "id"
CreateBot(ctx context.Context, userID uint64,
name, model, prompt string,
temperature float32,
maxTurnCount,
contextTurnCount int,
middlewareJson MiddlewareConfig, public bool,
) (uint64, error)
CreateBot(ctx context.Context, bot *Bot) (uint64, error)

// UPDATE @@table
// {{set}}
// "name"=@name,
// "model"=@model,
// "prompt"=@prompt,
// "temperature"=@temperature,
// "max_turn_count"=@maxTurnCount,
// "context_turn_count"=@contextTurnCount,
// "middleware_json"=@middlewareJson,
// "public"=@public,
// "name"=@bot.Name,
// "model"=@bot.Model,
// "prompt"=@bot.Prompt,
// "boundary_prompt"=@bot.BoundaryPrompt,
// "temperature"=@bot.Temperature,
// "max_turn_count"=@bot.MaxTurnCount,
// "context_turn_count"=@bot.ContextTurnCount,
// "middleware_json"=@bot.MiddlewareJson,
// "public"=@bot.Public,
// "updated_at"=NOW()
// {{end}}
// WHERE
// "id"=@id AND "deleted_at" is NULL
UpdateBot(ctx context.Context, id uint64,
name, model, prompt string,
temperature float32,
maxTurnCount,
contextTurnCount int,
middlewareJson MiddlewareConfig,
public bool,
) error
// "id"=@bot.ID AND "deleted_at" is NULL
UpdateBot(ctx context.Context, bot *Bot) error

// UPDATE @@table
// {{set}}
Expand All @@ -174,30 +152,47 @@ type (
GetBot(ctx context.Context, id uint64) (*Bot, error)
GetPublicBots(ctx context.Context) ([]*Bot, error)
GetBotsByUserID(ctx context.Context, userID uint64) ([]*Bot, error)
CreateBot(ctx context.Context, userID uint64, name, model, prompt string, temperature float32, maxTurnCount, contextTurnCount int, middlewares MiddlewareConfig, public bool) (*Bot, error)
UpdateBot(ctx context.Context, id uint64, name, model, prompt string, temperature float32, maxTurnCount, contextTurnCount int, middlewares MiddlewareConfig, public bool) error
CreateBot(ctx context.Context, b *Bot) error
UpdateBot(ctx context.Context, b *Bot) error
DeleteBot(ctx context.Context, id uint64) error
ReplaceStore(store BotStore) BotService
}
)

func (t *Bot) GetPrompt(conv *Conversation, question string, additionData map[string]any) string {
func (t *Bot) GetRequestContent(conv *Conversation, question string, additionData map[string]any) string {

var buf bytes.Buffer
data := map[string]interface{}{
"LangHint": conv.LangHint(),
"History": conv.HistoryToText(),
}
for k, v := range additionData {
data[k] = v
}

if t.PromptTpl == nil {
t.PromptTpl = template.Must(template.New(fmt.Sprintf("%d-prompt-tmpl", t.ID)).Parse(t.Prompt))
var result string
if t.Prompt != "" {
if t.PromptTpl == nil {
t.PromptTpl = template.Must(template.New(fmt.Sprintf("%d-prompt-tmpl", t.ID)).Parse(t.Prompt))
}
t.PromptTpl.Execute(&buf, data)

str := buf.String()
result = strings.TrimSpace(str) + "\n"
}

result += conv.HistoryToText()

if t.BoundaryPrompt != "" {
if t.BoundaryPromptTpl == nil {
t.BoundaryPromptTpl = template.Must(template.New(fmt.Sprintf("%d-boundary-prompt-tmpl", t.ID)).Parse(t.BoundaryPrompt))
}
t.BoundaryPromptTpl.Execute(&buf, data)

str := buf.String()
result += "\n" + strings.TrimSpace(str)
}
t.PromptTpl.Execute(&buf, data)

str := buf.String()
return strings.TrimSpace(str) + "\n"
return result
}

func (t *Bot) GetChatMessages(conv *Conversation, additionData map[string]any) []gogpt.ChatCompletionMessage {
Expand All @@ -210,19 +205,18 @@ func (t *Bot) GetChatMessages(conv *Conversation, additionData map[string]any) [
data[k] = v
}

if t.PromptTpl == nil {
t.PromptTpl = template.Must(template.New(fmt.Sprintf("%d-prompt-tmpl", t.ID)).Parse(t.Prompt))
}

t.PromptTpl.Execute(&buf, data)

str := buf.String()
result := []gogpt.ChatCompletionMessage{}
if t.Prompt != "" {
if t.PromptTpl == nil {
t.PromptTpl = template.Must(template.New(fmt.Sprintf("%d-prompt-tmpl", t.ID)).Parse(t.Prompt))
}

result := []gogpt.ChatCompletionMessage{
{
t.PromptTpl.Execute(&buf, data)
str := buf.String()
result = append(result, gogpt.ChatCompletionMessage{
Role: "system",
Content: str,
},
})
}

history := conv.History
Expand All @@ -243,5 +237,19 @@ func (t *Bot) GetChatMessages(conv *Conversation, additionData map[string]any) [
}
}

if t.BoundaryPrompt != "" {
if t.BoundaryPromptTpl == nil {
t.BoundaryPromptTpl = template.Must(template.New(fmt.Sprintf("%d-boundary-prompt-tmpl", t.ID)).Parse(t.BoundaryPrompt))
}

t.BoundaryPromptTpl.Execute(&buf, data)
str := buf.String()
result = append(result, gogpt.ChatCompletionMessage{
Role: "system",
Content: str,
})

}

return result
}
33 changes: 29 additions & 4 deletions handler/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type CreateOrUpdateBotPayload struct {
Name string `json:"name"`
Model string `json:"model"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
Temperature float32 `json:"temperature"`
MaxTurnCount int `json:"max_turn_count"`
ContextTurnCount int `json:"context_turn_count"`
Expand Down Expand Up @@ -169,13 +170,26 @@ func UpdateBot(botz core.BotService) http.HandlerFunc {
return
}

err = botz.UpdateBot(ctx, botID, body.Name, body.Model, body.Prompt, body.Temperature, body.MaxTurnCount, body.ContextTurnCount, body.Middlewares, false)
b := &core.Bot{
ID: bot.ID,
Name: body.Name,
UserID: bot.UserID,
Prompt: body.Prompt,
BoundaryPrompt: body.BoundaryPrompt,
Model: body.Model,
MaxTurnCount: body.MaxTurnCount,
ContextTurnCount: body.ContextTurnCount,
Temperature: body.Temperature,
MiddlewareJson: body.Middlewares,
Public: false,
}
err = botz.UpdateBot(ctx, b)
if err != nil {
render.Error(w, http.StatusInternalServerError, err)
return
}

render.JSON(w, bot)
render.JSON(w, b)
}
}

Expand Down Expand Up @@ -206,8 +220,19 @@ func CreateBot(botz core.BotService, models core.ModelStore, botPerUserLimit int
return
}

bot, err := botz.CreateBot(ctx, user.ID, body.Name, body.Model, body.Prompt, body.Temperature, body.MaxTurnCount, body.ContextTurnCount, body.Middlewares, false)
if err != nil {
bot := &core.Bot{
UserID: user.ID,
Name: body.Name,
Model: body.Model,
Prompt: body.Prompt,
BoundaryPrompt: body.BoundaryPrompt,
Temperature: body.Temperature,
MaxTurnCount: body.MaxTurnCount,
ContextTurnCount: body.ContextTurnCount,
MiddlewareJson: body.Middlewares,
Public: false,
}
if err := botz.CreateBot(ctx, bot); err != nil {
statusCode := http.StatusInternalServerError
if errors.Is(err, core.ErrInvalidModel) {
statusCode = http.StatusBadRequest
Expand Down
46 changes: 22 additions & 24 deletions service/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,58 +98,56 @@ func (s *service) GetBotsByUserID(ctx context.Context, userID uint64) ([]*core.B
return bots, nil
}

func (s *service) CreateBot(ctx context.Context,
id uint64,
name, model, prompt string,
temperature float32,
max_turn_count, context_turn_count int,
middlewares core.MiddlewareConfig,
public bool,
) (*core.Bot, error) {

func (s *service) CreateBot(ctx context.Context, bot *core.Bot) error {
// check model if exists
if _, err := s.models.GetModel(ctx, model); err != nil {
if _, err := s.models.GetModel(ctx, bot.Model); err != nil {
if store.IsNotFoundErr(err) {
return nil, core.ErrInvalidModel
return core.ErrInvalidModel
}
fmt.Printf("models.GetModel err: %v\n", err)
return nil, err
return err
}

botID, err := s.bots.CreateBot(ctx, id, name, model, prompt, temperature, max_turn_count, context_turn_count, middlewares, public)
botID, err := s.bots.CreateBot(ctx, bot)
if err != nil {
fmt.Printf("bots.CreateBot err: %v\n", err)
return nil, err
return err
}

bot, err := s.bots.GetBot(ctx, botID)
b, err := s.bots.GetBot(ctx, botID)
if err != nil {
fmt.Printf("bots.GetBot err: %v\n", err)
return nil, err
return err
}
*bot = *b

key := fmt.Sprintf("bot-%d", botID)

s.botCache.Set(key, bot, cache.DefaultExpiration)
s.botCache.Delete(fmt.Sprintf("user-bots-%d", id))
s.botCache.Delete(fmt.Sprintf("user-bots-%d", bot.UserID))

return bot, nil
return nil
}

func (s *service) UpdateBot(ctx context.Context, id uint64, name, model, prompt string, temperature float32, maxTurnCount, contextTurnCount int, middlewares core.MiddlewareConfig, public bool) error {
bot, err := s.bots.GetBot(ctx, id)
func (s *service) UpdateBot(ctx context.Context, bot *core.Bot) error {
err := s.bots.UpdateBot(ctx, bot)
if err != nil {
fmt.Printf("bots.GetBot err: %v\n", err)
if store.IsNotFoundErr(err) {
return core.ErrBotNotFound
}

fmt.Printf("bots.UpdateBot err: %v\n", err)
return err
}

err = s.bots.UpdateBot(ctx, id, name, model, prompt, temperature, maxTurnCount, contextTurnCount, middlewares, public)
b, err := s.bots.GetBot(ctx, bot.ID)
if err != nil {
fmt.Printf("bots.UpdateBot err: %v\n", err)
fmt.Printf("bots.GetBot err: %v\n", err)
return err
}
*bot = *b

s.botCache.Delete(fmt.Sprintf("bot-%d", id))
s.botCache.Delete(fmt.Sprintf("bot-%d", bot.ID))
s.botCache.Delete(fmt.Sprintf("user-bots-%d", bot.UserID))
return nil
}
Expand Down

0 comments on commit 30f9834

Please sign in to comment.