vertexai_palm_llm.go
package vertexai

import (
	"context"
	"errors"

	"github.com/portyl/langchaingo/callbacks"
	"github.com/portyl/langchaingo/llms"
	"github.com/portyl/langchaingo/llms/vertexai/internal/vertexaiclient"
	"github.com/portyl/langchaingo/schema"
)

var (
	ErrEmptyResponse            = errors.New("no response")
	ErrMissingProjectID         = errors.New("missing the GCP Project ID, set it in the GOOGLE_CLOUD_PROJECT environment variable") //nolint:lll
	ErrUnexpectedResponseLength = errors.New("unexpected length of response")
	ErrNotImplemented           = errors.New("not implemented")
)

type LLM struct {
	CallbacksHandler callbacks.Handler
	client           *vertexaiclient.PaLMClient
}

var (
	_ llms.LLM           = (*LLM)(nil)
	_ llms.LanguageModel = (*LLM)(nil)
)
// Call requests a completion for the given prompt.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	r, err := o.Generate(ctx, []string{prompt}, options...)
	if err != nil {
		return "", err
	}
	if len(r) == 0 {
		return "", ErrEmptyResponse
	}
	return r[0].Text, nil
}

// Generate requests completions for the given prompts.
func (o *LLM) Generate(ctx context.Context, prompts []string, options ...llms.CallOption) ([]*llms.Generation, error) {
	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMStart(ctx, prompts)
	}

	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	results, err := o.client.CreateCompletion(ctx, &vertexaiclient.CompletionRequest{
		Prompts:     prompts,
		MaxTokens:   opts.MaxTokens,
		Temperature: opts.Temperature,
	})
	if err != nil {
		return nil, err
	}

	generations := []*llms.Generation{}
	for _, r := range results {
		generations = append(generations, &llms.Generation{
			Text: r.Text,
		})
	}

	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMEnd(ctx, llms.LLMResult{Generations: [][]*llms.Generation{generations}})
	}
	return generations, nil
}
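
// Usage sketch (not part of the original file): assuming the llms package
// provides the usual WithMaxTokens and WithTemperature option constructors,
// a caller would request completions roughly like this:
//
//	gens, err := llm.Generate(ctx, []string{"Tell me a joke"},
//		llms.WithMaxTokens(256),
//		llms.WithTemperature(0.4),
//	)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(gens[0].Text)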

// CreateEmbedding creates embeddings for the given input texts.
func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float64, error) {
	embeddings, err := o.client.CreateEmbedding(ctx, &vertexaiclient.EmbeddingRequest{
		Input: inputTexts,
	})
	if err != nil {
		return [][]float64{}, err
	}
	if len(embeddings) == 0 {
		return [][]float64{}, ErrEmptyResponse
	}
	if len(inputTexts) != len(embeddings) {
		return embeddings, ErrUnexpectedResponseLength
	}
	return embeddings, nil
}
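
// Embedding sketch (illustrative only): one embedding vector is expected per
// input text, so the result can be ranged over together with the inputs.
//
//	vecs, err := llm.CreateEmbedding(ctx, []string{"hello", "world"})
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(len(vecs), len(vecs[0])) // number of inputs, embedding dimension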

// GeneratePrompt requests completions for the given prompt values.
func (o *LLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...llms.CallOption) (llms.LLMResult, error) { //nolint:lll
	return llms.GeneratePrompt(ctx, o, promptValues, options...)
}

// GetNumTokens returns the number of tokens in the given text.
func (o *LLM) GetNumTokens(text string) int {
	return llms.CountTokens(vertexaiclient.TextModelName, text)
}

// New returns a new VertexAI PaLM LLM.
func New(opts ...Option) (*LLM, error) {
	client, err := newClient(opts...)
	return &LLM{client: client}, err
}

func newClient(opts ...Option) (*vertexaiclient.PaLMClient, error) {
	// Ensure options are initialized only once.
	initOptions.Do(initOpts)

	options := &options{}
	*options = *defaultOptions // Copy default options.
	for _, opt := range opts {
		opt(options)
	}
	if len(options.projectID) == 0 {
		return nil, ErrMissingProjectID
	}
	return vertexaiclient.New(options.projectID, options.clientOptions...)
}
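
// End-to-end sketch (illustrative, not part of this file): assuming the
// default options read the project ID from GOOGLE_CLOUD_PROJECT, as the
// ErrMissingProjectID message suggests, a caller would construct the LLM and
// request a completion like so:
//
//	package main
//
//	import (
//		"context"
//		"fmt"
//		"log"
//
//		"github.com/portyl/langchaingo/llms/vertexai"
//	)
//
//	func main() {
//		llm, err := vertexai.New() // expects GOOGLE_CLOUD_PROJECT to be set
//		if err != nil {
//			log.Fatal(err)
//		}
//		completion, err := llm.Call(context.Background(), "What is Vertex AI PaLM?")
//		if err != nil {
//			log.Fatal(err)
//		}
//		fmt.Println(completion)
//	}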