-
-
Notifications
You must be signed in to change notification settings - Fork 696
/
Copy pathoption.go
210 lines (183 loc) · 5.95 KB
/
option.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
package googleai
import (
"net/http"
"os"
"reflect"
"cloud.google.com/go/vertexai/genai"
"google.golang.org/api/option"
)
// Options is a set of options for GoogleAI and Vertex clients.
type Options struct {
	// CloudProject is the GCP cloud project name; set by WithCloudProject.
	CloudProject string
	// CloudLocation is the GCP cloud location (region) name; set by
	// WithCloudLocation.
	CloudLocation string
	// DefaultModel is the content model name used when one is not explicitly
	// provided in a specific client invocation.
	DefaultModel string
	// DefaultEmbeddingModel is the embedding model name used when one is not
	// explicitly provided in a specific client invocation.
	DefaultEmbeddingModel string
	// DefaultCandidateCount is the candidate count for the model.
	DefaultCandidateCount int
	// DefaultMaxTokens is the maximum token count for the model.
	DefaultMaxTokens int
	// DefaultTemperature is the temperature for the model.
	DefaultTemperature float64
	// DefaultTopK is the TopK for the model.
	DefaultTopK int
	// DefaultTopP is the TopP for the model.
	DefaultTopP float64
	// HarmThreshold is the safety/harm setting for the model.
	HarmThreshold HarmBlockThreshold
	// ClientOptions accumulates the underlying client options appended by the
	// With* helpers (API key, credentials, HTTP client, REST transport).
	ClientOptions []option.ClientOption
}
// DefaultOptions returns an Options value populated with the package's default
// settings. CloudProject, CloudLocation and ClientOptions are left at their
// zero values.
func DefaultOptions() Options {
	var defaults Options

	// Model selection.
	defaults.DefaultModel = "gemini-pro"
	defaults.DefaultEmbeddingModel = "embedding-001"

	// Generation parameters.
	defaults.DefaultCandidateCount = 1
	defaults.DefaultMaxTokens = 2048
	defaults.DefaultTemperature = 0.5
	defaults.DefaultTopK = 3
	defaults.DefaultTopP = 0.95

	// Safety.
	defaults.HarmThreshold = HarmBlockOnlyHigh

	return defaults
}
// EnsureAuthPresent makes a best-effort attempt to give the client
// authentication information. When none of the configured client options
// carries auth and the GOOGLE_API_KEY environment variable is non-empty, the
// variable's value is added as an API-key option; otherwise the options are
// left untouched.
func (o *Options) EnsureAuthPresent() {
	if hasAuthOptions(o.ClientOptions) {
		return
	}

	envKey := os.Getenv("GOOGLE_API_KEY")
	if envKey == "" {
		return
	}

	apply := WithAPIKey(envKey)
	apply(o)
}
// Option is a function that modifies an Options value in place.
type Option func(*Options)
// WithAPIKey returns an Option that appends the given API key (token) to the
// client options. This is useful for googleai clients.
func WithAPIKey(apiKey string) Option {
	return func(o *Options) {
		keyOpt := option.WithAPIKey(apiKey)
		o.ClientOptions = append(o.ClientOptions, keyOpt)
	}
}
// WithCredentialsJSON returns an Option that appends a ClientOption
// authenticating API calls with the given service account or refresh token
// JSON credentials. Empty credentials are ignored and nothing is appended.
func WithCredentialsJSON(credentialsJSON []byte) Option {
	return func(o *Options) {
		if len(credentialsJSON) > 0 {
			credOpt := option.WithCredentialsJSON(credentialsJSON)
			o.ClientOptions = append(o.ClientOptions, credOpt)
		}
	}
}
// WithCredentialsFile returns an Option that appends a ClientOption
// authenticating API calls with the given service account or refresh token
// JSON credentials file. An empty file name is ignored and nothing is
// appended.
func WithCredentialsFile(credentialsFile string) Option {
	return func(o *Options) {
		if credentialsFile != "" {
			fileOpt := option.WithCredentialsFile(credentialsFile)
			o.ClientOptions = append(o.ClientOptions, fileOpt)
		}
	}
}
// WithRest returns an Option that configures the client to use the REST API
// by appending genai.WithREST() to the client options.
func WithRest() Option {
	return func(o *Options) {
		restOpt := genai.WithREST()
		o.ClientOptions = append(o.ClientOptions, restOpt)
	}
}
// WithHTTPClient returns an Option that appends a ClientOption telling the
// client to make its requests with the provided HTTP client.
// This is useful for vertex clients.
func WithHTTPClient(httpClient *http.Client) Option {
	return func(o *Options) {
		clientOpt := option.WithHTTPClient(httpClient)
		o.ClientOptions = append(o.ClientOptions, clientOpt)
	}
}
// WithCloudProject returns an Option that records the GCP cloud project name
// in Options.CloudProject. This is useful for vertex clients.
func WithCloudProject(p string) Option {
	return func(o *Options) { o.CloudProject = p }
}
// WithCloudLocation returns an Option that records the GCP cloud location
// (region) name in Options.CloudLocation. This is useful for vertex clients.
func WithCloudLocation(l string) Option {
	return func(o *Options) { o.CloudLocation = l }
}
// WithDefaultModel returns an Option that sets the default content model
// name. That name is used whenever a specific client invocation does not
// provide one explicitly.
func WithDefaultModel(defaultModel string) Option {
	return func(o *Options) { o.DefaultModel = defaultModel }
}
// WithDefaultEmbeddingModel passes a default embedding model name to the
// client. This model name is used if not explicitly provided in specific
// client invocations.
func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option {
	return func(opts *Options) {
		opts.DefaultEmbeddingModel = defaultEmbeddingModel
	}
}
// WithDefaultCandidateCount returns an Option that stores the candidate count
// for the model in Options.DefaultCandidateCount.
func WithDefaultCandidateCount(defaultCandidateCount int) Option {
	return func(o *Options) { o.DefaultCandidateCount = defaultCandidateCount }
}
// WithDefaultMaxTokens returns an Option that stores the maximum token count
// for the model in Options.DefaultMaxTokens.
func WithDefaultMaxTokens(maxTokens int) Option {
	return func(o *Options) { o.DefaultMaxTokens = maxTokens }
}
// WithDefaultTemperature sets the temperature for the model.
func WithDefaultTemperature(defaultTemperature float64) Option {
	return func(opts *Options) {
		opts.DefaultTemperature = defaultTemperature
	}
}
// WithDefaultTopK returns an Option that stores the TopK for the model in
// Options.DefaultTopK.
func WithDefaultTopK(defaultTopK int) Option {
	return func(o *Options) { o.DefaultTopK = defaultTopK }
}
// WithDefaultTopP returns an Option that stores the TopP for the model in
// Options.DefaultTopP.
func WithDefaultTopP(defaultTopP float64) Option {
	return func(o *Options) { o.DefaultTopP = defaultTopP }
}
// WithHarmThreshold returns an Option that sets the safety/harm setting for
// the model, potentially limiting any harmful content it may generate.
func WithHarmThreshold(ht HarmBlockThreshold) Option {
	return func(o *Options) { o.HarmThreshold = ht }
}
// HarmBlockThreshold selects how much potentially harmful content is allowed
// through before it is blocked.
type HarmBlockThreshold int32

// The thresholds are numbered by iota in declaration order, starting at 0.
// Do not reorder or insert entries: that would change their numeric values.
const (
	// HarmBlockUnspecified (0) means threshold is unspecified.
	HarmBlockUnspecified HarmBlockThreshold = iota
	// HarmBlockLowAndAbove (1) means content with NEGLIGIBLE will be allowed.
	HarmBlockLowAndAbove
	// HarmBlockMediumAndAbove (2) means content with NEGLIGIBLE and LOW will
	// be allowed.
	HarmBlockMediumAndAbove
	// HarmBlockOnlyHigh (3) means content with NEGLIGIBLE, LOW, and MEDIUM
	// will be allowed.
	HarmBlockOnlyHigh
	// HarmBlockNone (4) means all content will be allowed.
	HarmBlockNone
)
// hasAuthOptions reports whether opts already contains a client option that
// supplies authentication.
//
// Options are recognised by the name of their dynamic type, obtained through
// reflection. An API-key option counts only when its key is non-empty; an
// empty key is skipped and the scan continues, so a credential option later
// in the same slice is still detected. Nil entries are ignored.
func hasAuthOptions(opts []option.ClientOption) bool {
	for _, opt := range opts {
		if opt == nil {
			// reflect.ValueOf(nil) yields the zero Value, and calling Type on
			// it panics.
			continue
		}

		v := reflect.ValueOf(opt)
		switch v.Type().String() {
		case "option.withAPIKey":
			// An empty key carries no credentials: keep looking instead of
			// declaring the whole slice auth-less.
			if v.String() != "" {
				return true
			}
		case "option.withHTTPClient",
			"option.withTokenSource",
			"option.withCredentialsFile",
			// NOTE(review): option.WithCredentialsFile is believed to return
			// the unexported type withCredFile, so that spelling is matched
			// as well - confirm against the pinned google.golang.org/api
			// version.
			"option.withCredFile",
			"option.withCredentialsJSON":
			return true
		}
	}
	return false
}