forked from tmc/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
geminiclient.go
134 lines (114 loc) · 3.45 KB
/
geminiclient.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package geminiclient
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
// Default model and endpoint settings for the Gemini client.
const (
	// defaultBaseURL is empty, meaning the genai client's built-in
	// endpoint is used unless the caller supplies one.
	defaultBaseURL = ""
	defaultModel = "gemini-pro"
	// defaultVisionModel is substituted automatically by
	// CreateCompletion when the request carries images.
	defaultVisionModel = "gemini-pro-vision"
	// defaultEmbedModel is stored on every Client; no embedding call
	// is visible in this file.
	defaultEmbedModel = "embedding-001"
)
// ErrEmptyResponse is returned when the Gemini API returns an empty response.
var ErrEmptyResponse = errors.New("empty response")

// ErrMissToken is returned when the API key is not set.
var ErrMissToken = errors.New("api key is not set")

// ErrRateLimitResponse indicates a rate-limit failure.
// NOTE(review): the message references OpenAI's text-embedding-ada-002 —
// presumably copied from an OpenAI client; confirm and reword for Gemini.
var ErrRateLimitResponse = errors.New("rate limit reached for text-embedding-ada-002 in organization")
// Client is a client for the Gemini API.
type Client struct {
	token string // API key forwarded to genai.NewClient via option.WithAPIKey
	Model string // generative model name; falls back to defaultModel when empty
	baseURL string // optional custom endpoint; empty selects the library default
	embedModel string // embedding model name; set to defaultEmbedModel, unused in this file
}
// New returns a new Gemini client.
//
// model and baseURL may be empty, in which case defaultModel and
// defaultBaseURL are used. It returns ErrMissToken when token is empty,
// failing fast instead of deferring the error to the first API call.
func New(token string, model string, baseURL string) (*Client, error) {
	if token == "" {
		return nil, ErrMissToken
	}
	c := &Client{
		token:      token,
		Model:      model,
		baseURL:    baseURL,
		embedModel: defaultEmbedModel,
	}
	if c.Model == "" {
		c.Model = defaultModel
	}
	if c.baseURL == "" {
		c.baseURL = defaultBaseURL
	}
	return c, nil
}
// Completion is the result of a completion request.
type Completion struct {
	Text string `json:"text"` // concatenation of every streamed chunk
}
// CompletionRequest is a request to complete a completion.
type CompletionRequest struct {
	Model string `json:"model"`
	Prompt string `json:"prompt"`
	// Temperature and TopP are carried in the request but their
	// application to the model is commented out in CreateCompletion.
	Temperature float64 `json:"temperature,omitempty"`
	MaxTokens int `json:"max_tokens,omitempty"` // >0 sets the model's max output tokens
	N int `json:"n,omitempty"`
	// FrequencyPenalty, PresencePenalty and N are not read anywhere in
	// this file — presumably kept for interface parity; verify callers.
	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
	PresencePenalty float64 `json:"presence_penalty,omitempty"`
	TopP float64 `json:"top_p,omitempty"`
	StopWords []string `json:"stop,omitempty"` // mapped to the model's StopSequences
	// Images are sent as JPEG blobs; any image switches the request to
	// the vision model.
	Images [][]byte `json:"images"`
	// StreamingFunc is a function to be called for each chunk of a streaming response.
	// Return an error to stop streaming early.
	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
}
// CreateCompletion creates a completion for r, streaming each chunk to
// r.StreamingFunc when it is set and returning the concatenated text.
//
// When r.Images is non-empty the vision model is used instead of c.Model.
// It returns ErrEmptyResponse when the API yields a response with no
// candidates or no content parts, and it returns the callback's error if
// r.StreamingFunc reports one, stopping the stream early as the field's
// documentation promises.
func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*Completion, error) {
	opts := make([]option.ClientOption, 0, 2)
	opts = append(opts, option.WithAPIKey(c.token))
	if c.baseURL != "" {
		opts = append(opts, option.WithEndpoint(c.baseURL))
	}
	client, err := genai.NewClient(ctx, opts...)
	if err != nil {
		return nil, err
	}
	defer client.Close()

	// Any attached image forces the vision-capable model.
	model := client.GenerativeModel(c.Model)
	if len(r.Images) > 0 {
		model = client.GenerativeModel(defaultVisionModel)
	}
	model.StopSequences = r.StopWords
	// model.SetTemperature(float32(r.Temperature))
	// model.SetTopP(float32(r.TopP))
	if r.MaxTokens > 0 {
		model.SetMaxOutputTokens(int32(r.MaxTokens))
	}

	// Image parts first, prompt text last.
	blobs := make([]genai.Part, 0, len(r.Images)+1)
	for _, image := range r.Images {
		blobs = append(blobs, genai.ImageData("jpeg", image))
	}
	blobs = append(blobs, genai.Text(r.Prompt))

	iter := model.GenerateContentStream(ctx, blobs...)
	content := ""
	for {
		resp, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, err
		}
		// Guard against empty/blocked responses; the unchecked
		// Candidates[0].Content.Parts[0] access would otherwise panic.
		if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil ||
			len(resp.Candidates[0].Content.Parts) == 0 {
			return nil, ErrEmptyResponse
		}
		stream := fmt.Sprintf("%s", resp.Candidates[0].Content.Parts[0])
		if r.StreamingFunc != nil {
			// Propagate the callback's error so callers can abort the stream.
			if err := r.StreamingFunc(ctx, []byte(stream)); err != nil {
				return nil, err
			}
		}
		content += stream
	}
	return &Completion{
		Text: content,
	}, nil
}
// setHeaders stamps the default JSON content-type header onto req.
func (c *Client) setHeaders(req *http.Request) {
	const jsonContentType = "application/json"
	req.Header.Set("Content-Type", jsonContentType)
}