-
Notifications
You must be signed in to change notification settings - Fork 411
/
gpt.go
73 lines (63 loc) · 1.91 KB
/
gpt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package chat
import (
"context"
"os"
"github.com/pwh-pwh/aiwechat-vercel/config"
"github.com/pwh-pwh/aiwechat-vercel/db"
"github.com/sashabaranov/go-openai"
)
// SimpleGptChat is a chat backend that talks to an OpenAI-compatible
// chat-completion API through the go-openai client.
type SimpleGptChat struct {
	// token is passed to openai.DefaultConfig as the API credential.
	// url is assigned to the client configuration's BaseURL.
	token, url string
	// maxTokens is copied into the request's MaxTokens only when it is
	// greater than zero; otherwise the request leaves the field unset.
	maxTokens int
	// BaseChat is embedded; its definition lives outside this file section.
	BaseChat
}
// toDbMsg converts an OpenAI chat message into the db.Msg form, carrying
// over the role and the text content.
func (s *SimpleGptChat) toDbMsg(msg openai.ChatCompletionMessage) db.Msg {
	var stored db.Msg
	stored.Role = msg.Role
	stored.Msg = msg.Content
	return stored
}
// toChatMsg converts a db.Msg back into an OpenAI chat message, carrying
// over the role and the text content. It is the inverse of toDbMsg.
func (s *SimpleGptChat) toChatMsg(msg db.Msg) openai.ChatCompletionMessage {
	var out openai.ChatCompletionMessage
	out.Role = msg.Role
	out.Content = msg.Msg
	return out
}
// getModel picks the model name to use for userID. Precedence:
//  1. the value returned by db.GetModel for this user and the GPT bot type,
//     when the lookup succeeds and the value is non-empty;
//  2. the "gptModel" environment variable, when non-empty;
//  3. the fallback "gpt-3.5-turbo".
func (s *SimpleGptChat) getModel(userID string) string {
	stored, err := db.GetModel(userID, config.Bot_Type_Gpt)
	if err == nil && stored != "" {
		return stored
	}
	if fromEnv := os.Getenv("gptModel"); fromEnv != "" {
		return fromEnv
	}
	return "gpt-3.5-turbo"
}
// chat sends msg, together with the message list returned by
// GetMsgListWithDb for userID, to the chat-completion endpoint and returns
// the assistant's reply text.
//
// On a request error the error text itself is returned as the reply. When
// the API answers successfully but with no choices, a short diagnostic
// string is returned instead of indexing into an empty slice (which would
// panic). In both failure cases the message list is not saved.
func (s *SimpleGptChat) chat(userID, msg string) string {
	cfg := openai.DefaultConfig(s.token)
	// Only override the endpoint when one is configured; an empty BaseURL
	// would discard the library default and make every request fail.
	if s.url != "" {
		cfg.BaseURL = s.url
	}
	client := openai.NewClientWithConfig(cfg)

	userMsg := openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: msg}
	msgs := GetMsgListWithDb(config.Bot_Type_Gpt, userID, userMsg, s.toDbMsg, s.toChatMsg)

	req := openai.ChatCompletionRequest{
		Model:    s.getModel(userID),
		Messages: msgs,
	}
	// Add the MaxTokens parameter only when a valid (positive) limit was
	// configured; otherwise leave it unset.
	// Field name reference: https://github.com/sashabaranov/go-openai
	if s.maxTokens > 0 {
		req.MaxTokens = s.maxTokens
	}

	resp, err := client.CreateChatCompletion(context.Background(), req)
	if err != nil {
		return err.Error()
	}
	// A successful response is not guaranteed to contain a choice.
	if len(resp.Choices) == 0 {
		return "chat completion returned no choices"
	}

	content := resp.Choices[0].Message.Content
	msgs = append(msgs, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant, Content: content})
	SaveMsgListWithDb(config.Bot_Type_Gpt, userID, msgs, s.toDbMsg)
	return content
}
// Chat is the public entry point for a user message. The message is first
// offered to DoAction; when DoAction's second return value is true, its
// reply is returned as-is. Otherwise the message is forwarded to s.chat
// through WithTimeChat.
func (s *SimpleGptChat) Chat(userID string, msg string) string {
	if reply, handled := DoAction(userID, msg); handled {
		return reply
	}
	return WithTimeChat(userID, msg, s.chat)
}