forked from ankit-arora/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
option.go
132 lines (116 loc) · 3.83 KB
/
option.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package googleai
// Options is a set of options for GoogleAI and Vertex clients. A baseline is
// produced by DefaultOptions and individual fields are overridden by applying
// Option functions (WithAPIKey, WithDefaultModel, ...).
type Options struct {
	// APIKey is the API key (token) for the client; relevant for googleai clients.
	APIKey string

	// CloudProject is the GCP cloud project name; relevant for Vertex clients.
	CloudProject string

	// CloudLocation is the GCP cloud location (region); relevant for Vertex clients.
	CloudLocation string

	// DefaultModel is the content model name used when a call does not name one.
	DefaultModel string

	// DefaultEmbeddingModel is the embedding model name used when a call does
	// not name one.
	DefaultEmbeddingModel string

	// DefaultCandidateCount is the default candidate count for the model.
	DefaultCandidateCount int

	// DefaultMaxTokens is the default maximum token count for the model.
	DefaultMaxTokens int

	// DefaultTemperature is the default temperature for the model.
	DefaultTemperature float64

	// DefaultTopK is the default TopK for the model.
	DefaultTopK int

	// DefaultTopP is the default TopP for the model.
	DefaultTopP float64

	// HarmThreshold is the safety/harm setting applied to the model.
	HarmThreshold HarmBlockThreshold
}
// DefaultOptions returns the baseline Options that Option functions override.
//
// The defaults are:
//   - content model "gemini-pro" and embedding model "embedding-001"
//   - 1 candidate and at most 1024 tokens
//   - temperature 0.5, TopK 3, TopP 0.95
//   - harm threshold HarmBlockOnlyHigh
//
// APIKey, CloudProject and CloudLocation are left empty.
func DefaultOptions() Options {
	return Options{
		APIKey:                "",
		CloudProject:          "",
		CloudLocation:         "",
		DefaultModel:          "gemini-pro",
		DefaultEmbeddingModel: "embedding-001",
		DefaultCandidateCount: 1,
		DefaultMaxTokens:      1024,
		DefaultTemperature:    0.5,
		DefaultTopK:           3,
		DefaultTopP:           0.95,
		HarmThreshold:         HarmBlockOnlyHigh,
	}
}
// Option mutates an Options value in place; it is the type returned by the
// With* functions in this file.
type Option func(*Options)
// WithAPIKey returns an Option that stores apiKey, the API key (token), in
// Options.APIKey. It is relevant for googleai clients.
func WithAPIKey(apiKey string) Option {
	setAPIKey := func(o *Options) { o.APIKey = apiKey }
	return setAPIKey
}
// WithCloudProject returns an Option that records p as the GCP cloud project
// name in Options.CloudProject. It is relevant for Vertex clients.
func WithCloudProject(p string) Option {
	return Option(func(o *Options) {
		o.CloudProject = p
	})
}
// WithCloudLocation returns an Option that records l, the GCP cloud location
// (region) name, in Options.CloudLocation. It is relevant for Vertex clients.
func WithCloudLocation(l string) Option {
	var apply Option = func(o *Options) {
		o.CloudLocation = l
	}
	return apply
}
// WithDefaultModel returns an Option that sets Options.DefaultModel, the
// content model name the client falls back to when a specific invocation does
// not provide one.
func WithDefaultModel(defaultModel string) Option {
	model := defaultModel
	return func(o *Options) { o.DefaultModel = model }
}
// WithDefaultEmbeddingModel passes a default embedding model name to the
// client. This model name is used if not explicitly provided in specific
// client invocations.
func WithDefaultEmbeddingModel(defaultEmbeddingModel string) Option {
	return func(opts *Options) {
		opts.DefaultEmbeddingModel = defaultEmbeddingModel
	}
}
// WithDefaultCandidateCount returns an Option that sets
// Options.DefaultCandidateCount, the candidate count for the model.
func WithDefaultCandidateCount(defaultCandidateCount int) Option {
	return Option(func(cfg *Options) { cfg.DefaultCandidateCount = defaultCandidateCount })
}
// WithDefaultMaxTokens returns an Option that sets Options.DefaultMaxTokens,
// the maximum token count for the model.
func WithDefaultMaxTokens(maxTokens int) Option {
	limit := maxTokens
	setLimit := func(cfg *Options) {
		cfg.DefaultMaxTokens = limit
	}
	return setLimit
}
// WithDefaultTemperature sets the default temperature for the model
// (Options.DefaultTemperature).
func WithDefaultTemperature(defaultTemperature float64) Option {
	return func(opts *Options) {
		opts.DefaultTemperature = defaultTemperature
	}
}
// WithDefaultTopK returns an Option that sets Options.DefaultTopK, the TopK
// for the model.
func WithDefaultTopK(defaultTopK int) Option {
	var fn Option = func(o *Options) { o.DefaultTopK = defaultTopK }
	return fn
}
// WithDefaultTopP returns an Option that sets Options.DefaultTopP, the TopP
// for the model.
func WithDefaultTopP(defaultTopP float64) Option {
	topP := defaultTopP
	return Option(func(o *Options) {
		o.DefaultTopP = topP
	})
}
// WithHarmThreshold returns an Option that sets Options.HarmThreshold, the
// safety/harm setting for the model, which can limit harmful content it may
// generate.
func WithHarmThreshold(ht HarmBlockThreshold) Option {
	setThreshold := func(o *Options) {
		o.HarmThreshold = ht
	}
	return setThreshold
}
// HarmBlockThreshold is the safety setting applied via WithHarmThreshold.
//
// Apart from HarmBlockUnspecified and HarmBlockNone, each value names the
// lowest harm level that gets blocked; content below that level is allowed.
//
// The numeric values are spelled out explicitly rather than via iota so they
// stay stable. NOTE(review): they presumably mirror the upstream API's enum;
// confirm against the client-side conversion before changing any of them.
type HarmBlockThreshold int32

const (
	// HarmBlockUnspecified means threshold is unspecified.
	HarmBlockUnspecified HarmBlockThreshold = 0
	// HarmBlockLowAndAbove means content with NEGLIGIBLE will be allowed.
	HarmBlockLowAndAbove HarmBlockThreshold = 1
	// HarmBlockMediumAndAbove means content with NEGLIGIBLE and LOW will be allowed.
	HarmBlockMediumAndAbove HarmBlockThreshold = 2
	// HarmBlockOnlyHigh means content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
	HarmBlockOnlyHigh HarmBlockThreshold = 3
	// HarmBlockNone means all content will be allowed.
	HarmBlockNone HarmBlockThreshold = 4
)