package vertexai

import (
	"context"
	"strings"

	"github.com/portyl/langchaingo/llms"
	"github.com/portyl/langchaingo/llms/vertexai/internal/vertexaiclient"
	"github.com/portyl/langchaingo/schema"
)

const (
	userAuthor = "user"
	botAuthor  = "bot"
)
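
// ChatMessage is an alias for the Vertex AI client's chat message type.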
type ChatMessage = vertexaiclient.ChatMessage
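
// Chat is a chat LLM backed by the Vertex AI PaLM client.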
type Chat struct {
	client *vertexaiclient.PaLMClient
}
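
// Compile-time checks that Chat implements the llms interfaces.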
var (
	_ llms.ChatLLM       = (*Chat)(nil)
	_ llms.LanguageModel = (*Chat)(nil)
)

// Call requests a chat response for the given messages.
func (o *Chat) Call(ctx context.Context, messages []schema.ChatMessage, options ...llms.CallOption) (*schema.AIChatMessage, error) { // nolint: lll
	r, err := o.Generate(ctx, [][]schema.ChatMessage{messages}, options...)
	if err != nil {
		return nil, err
	}
	if len(r) == 0 {
		return nil, ErrEmptyResponse
	}
	return r[0].Message, nil
}

// Generate requests a chat response for each of the sets of messages.
func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...llms.CallOption) ([]*llms.Generation, error) { // nolint: lll
	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	generations := make([]*llms.Generation, 0, len(messageSets))
	for _, messages := range messageSets {
		// Fold every system message into the request context, one per line.
		var contextPieces []string
		for _, m := range messages {
			if m.GetType() == schema.ChatMessageTypeSystem {
				contextPieces = append(contextPieces, m.GetContent())
			}
		}

		msgs := toClientChatMessage(messages)
		result, err := o.client.CreateChat(ctx, &vertexaiclient.ChatRequest{
			Context:         strings.Join(contextPieces, "\n"),
			Temperature:     opts.Temperature,
			TopP:            opts.TopP,
			TopK:            opts.TopK,
			MaxOutputTokens: opts.MaxTokens,
			Messages:        msgs,
			Model:           opts.Model,
			StreamingFunc:   opts.StreamingFunc,
		})
		if err != nil {
			return nil, err
		}
		if len(result.Candidates) == 0 {
			return nil, ErrEmptyResponse
		}
		// Surface only the first candidate as this set's generation.
		generations = append(generations, &llms.Generation{
			Message: &schema.AIChatMessage{
				Content: result.Candidates[0].Content,
			},
			Text: result.Candidates[0].Content,
		})
	}

	return generations, nil
}
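
// GeneratePrompt converts the prompt values to chat messages and delegates
// to Generate via the shared llms.GenerateChatPrompt helper.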
func (o *Chat) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
	return llms.GenerateChatPrompt(ctx, o, promptValues, options...)
}
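
// GetNumTokens returns the number of tokens in text, as counted by
// llms.CountTokens for the PaLM text model.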
func (o *Chat) GetNumTokens(text string) int {
	return llms.CountTokens(vertexaiclient.TextModelName, text)
}
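
// toClientChatMessage converts schema chat messages to client messages:
// AI messages get the bot author, human messages the user author, and any
// other message type is dropped. A message that also implements
// schema.Named has its name used as the author instead.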
func toClientChatMessage(messages []schema.ChatMessage) []*vertexaiclient.ChatMessage {
	var msgs []*vertexaiclient.ChatMessage
	for _, m := range messages {
		msg := &vertexaiclient.ChatMessage{
			Content: m.GetContent(),
		}
		switch m.GetType() {
		case schema.ChatMessageTypeAI:
			msg.Author = botAuthor
		case schema.ChatMessageTypeHuman:
			msg.Author = userAuthor
		}
		if msg.Author == "" {
			continue
		}
		if n, ok := m.(schema.Named); ok {
			msg.Author = n.GetName()
		}
		msgs = append(msgs, msg)
	}
	return msgs
}

// NewChat returns a new VertexAI PaLM Chat LLM.
func NewChat(opts ...Option) (*Chat, error) {
	client, err := newClient(opts...)
	return &Chat{client: client}, err
}

// CreateEmbedding creates embeddings for the given input texts.
func (o *Chat) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float64, error) {
	embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{
		Input: inputTexts,
	})
	if err != nil {
		return nil, err
	}
	if len(embeddings) == 0 {
		return nil, ErrEmptyResponse
	}
	// Expect exactly one embedding vector per input text.
	if len(inputTexts) != len(embeddings) {
		return embeddings, ErrUnexpectedResponseLength
	}
	return embeddings, nil
}
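
// A minimal usage sketch (illustrative only; it assumes this package's
// Option values supply the Google Cloud project and credentials, and that
// the schema message types carry a Content field, as AIChatMessage does
// above):
//
//	chat, err := NewChat(/* project and credential options */)
//	if err != nil {
//		// handle error
//	}
//	resp, err := chat.Call(ctx, []schema.ChatMessage{
//		schema.SystemChatMessage{Content: "You are a concise assistant."},
//		schema.HumanChatMessage{Content: "What does Vertex AI PaLM do?"},
//	})
//	if err != nil {
//		// handle error
//	}
//	fmt.Println(resp.Content) // resp is a *schema.AIChatMessage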