Skip to content

Commit

Permalink
Save conversations to database
Browse files Browse the repository at this point in the history
* Update /conversations/oneway API, add `conversation_id` and `timeout` params. If `conversation_id` is set, then use the specific conversation rather than create a new one.
  • Loading branch information
xwjdsh authored and lyricat committed Apr 11, 2023
1 parent 1b16e4a commit cf742e2
Show file tree
Hide file tree
Showing 10 changed files with 590 additions and 112 deletions.
2 changes: 1 addition & 1 deletion cmd/httpd/httpd.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func NewCmdHttpd() *cobra.Command {

middlewarez := middlewareServ.New(middlewareServ.Config{}, apps, indexService)
botz := botServ.New(botServ.Config{}, apps, bots, models, middlewarez)
convz := convServ.New(convServ.Config{}, convs, botz)
convz := convServ.New(convServ.Config{}, convs, botz, apps)
orderz := orderServ.New(orderServ.Config{
PayeeId: cfg.Mixpay.PayeeId,
QuoteAssetId: cfg.Mixpay.QuoteAssetId,
Expand Down
63 changes: 30 additions & 33 deletions core/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,18 @@ func (mr MiddlewareResults) Value() (driver.Value, error) {

type (
Conversation struct {
ID string `json:"id"`
Bot *Bot `json:"bot"`
App *App `json:"app"`
UserIdentity string `json:"user_identity"`
Lang string `json:"lang"`
History []*ConvTurn `json:"history"`
ExpiredAt time.Time `json:"expired_at"`
ID string `json:"id"`
Lang string `json:"lang"`
UserIdentity string `json:"user_identity"`
BotID uint64 `json:"bot_id"`
AppID uint64 `json:"app_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt *time.Time `json:"-"`

Bot *Bot `gorm:"-" json:"bot"`
App *App `gorm:"-" json:"app"`
History []*ConvTurn `gorm:"-" json:"history"`
}

ConvTurn struct {
Expand All @@ -81,6 +86,21 @@ type (
}

ConversationStore interface {

// INSERT INTO "conversations"
// (
// id, lang, user_identity, bot_id, app_id, created_at, updated_at
// ) VALUES (
// @conv.ID, @conv.Lang, @conv.UserIdentity, @conv.BotID, @conv.AppID, NOW(), NOW()
// )
CreateConversation(ctx context.Context, conv *Conversation) error

// SELECT * FROM "conversations" WHERE id = @id AND deleted_at IS NULL
GetConversation(ctx context.Context, id string) (*Conversation, error)

// SELECT * FROM "conv_turns" WHERE conversation_id = @conversationID ORDER BY id DESC LIMIT @limit
GetConvTurnsByConversationID(ctx context.Context, conversationID string, limit int) ([]*ConvTurn, error)

// INSERT INTO "conv_turns"
// (
// "conversation_id", "bot_id", "app_id", "user_id",
Expand All @@ -98,35 +118,17 @@ type (
// RETURNING "id"
CreateConvTurn(ctx context.Context, convID string, botID, appID, userID uint64, uid, request string, bo BotOverride) (uint64, error)

// SELECT
// "id",
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override",
// "created_at", "updated_at"
// SELECT *
// FROM "conv_turns" WHERE
// "id" IN (@ids)
GetConvTurns(ctx context.Context, ids []uint64) ([]*ConvTurn, error)

// SELECT
// "id",
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override", "middleware_results",
// "created_at", "updated_at"
// SELECT *
// FROM "conv_turns" WHERE
// "id" = @id
GetConvTurn(ctx context.Context, id uint64) (*ConvTurn, error)

// SELECT
// "id",
// "conversation_id", "bot_id", "app_id", "user_id",
// "user_identity",
// "request", "response", "status",
// "prompt_tokens", "completion_tokens", "total_tokens", "bot_override", "middleware_results",
// "created_at", "updated_at"
// SELECT *
// FROM "conv_turns"
// {{where}}
// "status" IN (@status)
Expand Down Expand Up @@ -155,18 +157,13 @@ type (

ConversationService interface {
CreateConversation(ctx context.Context, botID, appID uint64, userIdentity, lang string) (*Conversation, error)
ClearExpiredConversations(ctx context.Context) error
DeleteConversation(ctx context.Context, convID string) error
GetConversation(ctx context.Context, convID string) (*Conversation, error)
PostToConversation(ctx context.Context, conv *Conversation, input string, bo BotOverride) (*ConvTurn, error)
ReplaceStore(store ConversationStore) ConversationService
}
)

func (c *Conversation) IsExpired() bool {
return c.ExpiredAt.Before(time.Now())
}

func (c *Conversation) HistoryToText() string {
lines := make([]string, 0)
history := c.History
Expand Down
2 changes: 2 additions & 0 deletions core/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func MatchFlag(err error, flag int) bool {
}

var (
ErrInternalServer = errors.New("internal server error")

ErrInsufficientCredit = errors.New("insufficient credit")

ErrMinAmountNotSatisfied = errors.New("min amount not satisfied")
Expand Down
46 changes: 27 additions & 19 deletions handler/conv/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/fox-one/pkg/httputil/param"
"github.com/go-chi/chi"
"github.com/google/uuid"
"github.com/pandodao/botastic/core"
"github.com/pandodao/botastic/handler/render"
"github.com/pandodao/botastic/internal/chanhub"
Expand All @@ -35,10 +34,12 @@ type (
}

CreateOnewayConversationPayload struct {
BotID uint64 `json:"bot_id"`
CreateConversationPayload
ConversationID string `json:"conversation_id"`

BotOverride core.BotOverride `json:"bot_override"`
Content string `json:"content"`
Lang string `json:"lang"`
Timeout time.Duration `json:"timeout"`
}
)

Expand Down Expand Up @@ -135,22 +136,20 @@ func GetConversationTurn(botz core.BotService, convs core.ConversationStore, hub
func GetConversation(botz core.BotService, convz core.ConversationService) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
app := session.AppFrom(ctx)

conversationID := chi.URLParam(r, "conversationID")
if conversationID == "" {
render.Error(w, http.StatusBadRequest, nil)
return
}

conv, err := convz.GetConversation(ctx, conversationID)
if err != nil || conv == nil {
render.Error(w, http.StatusNotFound, err)
return
}

if conv.App.ID != app.ID {
render.Error(w, http.StatusNotFound, nil)
if err != nil {
switch err {
case core.ErrConvNotFound, core.ErrBotNotFound:
render.Error(w, http.StatusNotFound, err)
default:
render.Error(w, http.StatusInternalServerError, err)
}
return
}

Expand Down Expand Up @@ -249,9 +248,6 @@ func UpdateConversation() http.HandlerFunc {
func CreateOnewayConversation(convz core.ConversationService, convs core.ConversationStore, hub *chanhub.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

app := session.AppFrom(ctx)

body := &CreateOnewayConversationPayload{}
Expand All @@ -260,6 +256,13 @@ func CreateOnewayConversation(convz core.ConversationService, convs core.Convers
return
}

timeout := 10 * time.Second
if body.Timeout > 0 {
timeout = body.Timeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

if body.BotID <= 0 {
render.Error(w, http.StatusBadRequest, nil)
return
Expand All @@ -270,9 +273,16 @@ func CreateOnewayConversation(convz core.ConversationService, convs core.Convers
return
}

uid := uuid.New().String()
var (
conv *core.Conversation
err error
)
if body.ConversationID == "" {
conv, err = convz.CreateConversation(ctx, body.BotID, app.ID, body.UserIdentity, body.Lang)
} else {
conv, err = convz.GetConversation(ctx, body.ConversationID)
}

conv, err := convz.CreateConversation(ctx, body.BotID, app.ID, uid, body.Lang)
if err != nil {
render.Error(w, http.StatusInternalServerError, err)
return
Expand All @@ -285,7 +295,6 @@ func CreateOnewayConversation(convz core.ConversationService, convs core.Convers
}

turnIDStr := strconv.FormatUint(turn.ID, 10)

_, err = hub.AddAndWait(ctx, turnIDStr)
if err != nil {
if err == context.Canceled {
Expand All @@ -301,6 +310,5 @@ func CreateOnewayConversation(convz core.ConversationService, convs core.Convers
}

render.JSON(w, turn)

}
}
86 changes: 69 additions & 17 deletions service/conv/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ import (
"github.com/google/uuid"
"github.com/pandodao/botastic/core"
"github.com/pandodao/botastic/session"
"github.com/pandodao/botastic/store"
)

func New(
cfg Config,
convs core.ConversationStore,
botz core.BotService,
apps core.AppStore,
) *service {
return &service{
cfg: cfg,
convs: convs,
botz: botz,
apps: apps,
conversationMap: make(map[string]*core.Conversation),
}
}
Expand All @@ -30,12 +33,13 @@ type (
cfg Config
convs core.ConversationStore
botz core.BotService
apps core.AppStore
conversationMap map[string]*core.Conversation
}
)

func (s *service) ReplaceStore(convs core.ConversationStore) core.ConversationService {
return New(s.cfg, convs, s.botz)
return New(s.cfg, convs, s.botz, s.apps)
}

func (s *service) CreateConversation(ctx context.Context, botID, appID uint64, userIdentity, lang string) (*core.Conversation, error) {
Expand All @@ -49,17 +53,77 @@ func (s *service) CreateConversation(ctx context.Context, botID, appID uint64, u
if !bot.Public && app.UserID != bot.UserID {
return nil, core.ErrBotNotFound
}
if lang == "" {
lang = "en"
}

conv := s.getDefaultConversation(app, bot, userIdentity, lang)
s.conversationMap[conv.ID] = conv
now := time.Now()
conv := &core.Conversation{
ID: uuid.New().String(),
AppID: app.ID,
BotID: bot.ID,
UserIdentity: userIdentity,
Lang: lang,
CreatedAt: now,
UpdatedAt: now,
Bot: bot,
App: app,
}

if err := s.convs.CreateConversation(ctx, conv); err != nil {
return nil, err
}

s.conversationMap[conv.ID] = conv
return conv, nil
}

func (s *service) GetConversation(ctx context.Context, convID string) (*core.Conversation, error) {
conv, ok := s.conversationMap[convID]
if !ok {
return nil, core.ErrConvNotFound
// load from db
var err error
conv, err = s.convs.GetConversation(ctx, convID)
if err != nil {
if store.IsNotFoundErr(err) {
return nil, core.ErrConvNotFound
}
return nil, core.ErrInternalServer
}

app := session.AppFrom(ctx)
if app != nil {
if conv.AppID != app.ID {
return nil, core.ErrConvNotFound
}
conv.App = app
} else {
app, err := s.apps.GetApp(ctx, conv.AppID)
if err != nil {
if store.IsNotFoundErr(err) {
return nil, core.ErrAppNotFound
}
return nil, core.ErrInternalServer
}
conv.App = app
}

bot, err := s.botz.GetBot(ctx, conv.BotID)
if err != nil {
if store.IsNotFoundErr(err) {
return nil, core.ErrBotNotFound
}
return nil, core.ErrInternalServer
}
conv.Bot = bot

// load history
turns, err := s.convs.GetConvTurnsByConversationID(ctx, conv.ID, bot.MaxTurnCount)
if err != nil {
return nil, core.ErrInternalServer
}

conv.History = turns
}

ids := []uint64{}
Expand Down Expand Up @@ -97,13 +161,11 @@ func (s *service) PostToConversation(ctx context.Context, conv *core.Conversatio
return nil, err
}

turns, err := s.convs.GetConvTurns(ctx, []uint64{turnID})
turn, err := s.convs.GetConvTurn(ctx, turnID)
if err != nil {
return nil, err
}

turn := turns[0]

bot, err := s.botz.GetBot(ctx, conv.Bot.ID)
if err != nil {
return nil, err
Expand All @@ -119,15 +181,6 @@ func (s *service) PostToConversation(ctx context.Context, conv *core.Conversatio
return turn, nil
}

func (s *service) ClearExpiredConversations(ctx context.Context) error {
for key, conv := range s.conversationMap {
if conv.IsExpired() {
delete(s.conversationMap, key)
}
}
return nil
}

func (s *service) DeleteConversation(ctx context.Context, convID string) error {
delete(s.conversationMap, convID)
return nil
Expand All @@ -144,6 +197,5 @@ func (s *service) getDefaultConversation(app *core.App, bot *core.Bot, uid, lang
UserIdentity: uid,
Lang: lang,
History: []*core.ConvTurn{},
ExpiredAt: time.Now().Add(10 * time.Minute),
}
}
2 changes: 1 addition & 1 deletion store/conv/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func init() {
OutPath: "store/conv/dao",
},
func(g *gen.Generator) {
g.ApplyInterface(func(core.ConversationStore) {}, core.ConvTurn{})
g.ApplyInterface(func(core.ConversationStore) {}, core.ConvTurn{}, core.Conversation{})
},
)
}
Expand Down

0 comments on commit cf742e2

Please sign in to comment.