-
Notifications
You must be signed in to change notification settings - Fork 8
/
llm_google.jl
222 lines (199 loc) · 9.31 KB
/
llm_google.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
## Rendering of conversation history for the Google (Gemini) API
## There is no system message role; we merge SystemMessage into the UserMessage, see below
## Gemini has no dedicated system role, so SystemMessage is sent as "user"
## (and later merged with the adjacent user message in `render`).
role4render(schema::AbstractGoogleSchema, msg::SystemMessage) = "user"
## Gemini calls the assistant role "model".
role4render(schema::AbstractGoogleSchema, msg::AIMessage) = "model"
"""
    render(schema::AbstractGoogleSchema,
        messages::Vector{<:AbstractMessage};
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        kwargs...)

Builds a history of the conversation to provide the prompt to the API. All unspecified kwargs are passed as replacements such that `{{key}}=>value` in the template.

Consecutive messages rendered with the "user" role are merged into a single message
(Gemini requires strictly alternating "user"/"model" roles and has no system role).

# Keyword Arguments
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector.

# Returns
A `Vector{Dict{Symbol, Any}}` where each element has `:role` and `:parts` keys in the Gemini wire format.
"""
function render(schema::AbstractGoogleSchema,
        messages::Vector{<:AbstractMessage};
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        kwargs...)
    ##
    ## First pass: keep the message types but make the replacements provided in `kwargs`
    messages_replaced = render(NoSchema(), messages; conversation, kwargs...)
    ## Second pass: convert to the Google (Gemini) schema
    conversation = Dict{Symbol, Any}[]
    # replace any handlebar variables in the messages
    for msg in messages_replaced
        push!(conversation,
            Dict(
                :role => role4render(schema, msg), :parts => [Dict("text" => msg.content)]))
    end
    ## Merge any run of consecutive "user" messages into one message.
    ## NOTE: merging only adjacent *pairs* (the previous approach) leaves two adjacent
    ## "user" entries whenever 3+ user-role messages occur in a row; folding each
    ## message into the tail of the accumulator merges runs of any length.
    merged_conversation = Dict{Symbol, Any}[]
    for msg in conversation
        if !isempty(merged_conversation) && msg[:role] == "user" &&
           merged_conversation[end][:role] == "user"
            ## Concatenate the user texts, separated by two newlines
            ## (skip the separator if either side is empty)
            txt1 = merged_conversation[end][:parts][1]["text"]
            txt2 = msg[:parts][1]["text"]
            merged_text = isempty(txt1) || isempty(txt2) ? txt1 * txt2 :
                          txt1 * "\n\n" * txt2
            merged_conversation[end] = Dict(
                :role => "user", :parts => [Dict("text" => merged_text)])
        else
            push!(merged_conversation, msg)
        end
    end
    return merged_conversation
end
"""
    ggi_generate_content

Stub with no methods — extended in the package extension `GoogleGenAIPromptingToolsExt`
(`ggi` stands for GoogleGenAI).
"""
function ggi_generate_content end
## Echo method for testing: instead of calling the API, it records the model and the
## rendered conversation on the schema object and returns the schema itself.
function ggi_generate_content(schema::TestEchoGoogleSchema, api_key::AbstractString,
        model::AbstractString,
        conversation; kwargs...)
    schema.inputs = conversation
    schema.model_id = model
    return schema
end
## User-Facing API
"""
    aigenerate(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        verbose::Bool = true,
        api_key::String = GOOGLE_API_KEY,
        model::String = "gemini-pro", return_all::Bool = false, dry_run::Bool = false,
        http_kwargs::NamedTuple = (retry_non_idempotent = true,
            retries = 5,
            readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(),
        kwargs...)

Generate an AI response based on a given prompt using the Google Gemini API. Get the API key [here](https://ai.google.dev/).

Note:
- There is no "cost" reported as of February 2024, as all access seems to be free-of-charge. See the details [here](https://ai.google.dev/pricing).
- `tokens` in the returned AIMessage are actually characters, not tokens. We use a _conservative_ estimate as they are not provided by the API yet.

# Arguments
- `prompt_schema`: An optional object to specify which prompt template should be applied (Default to `PROMPT_SCHEMA = OpenAISchema`)
- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate`
- `verbose`: A boolean indicating whether to print additional information.
- `api_key`: A string representing the API key for accessing the Google Gemini API.
- `model`: A string representing the model to use for generating the response. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`. Defaults to `"gemini-pro"`.
- `return_all::Bool=false`: If `true`, returns the entire conversation history, otherwise returns only the last message (the `AIMessage`).
- `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`).
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector.
- `http_kwargs`: A named tuple of HTTP keyword arguments.
- `api_kwargs`: A named tuple of API keyword arguments.
- `kwargs`: Prompt variables to be used to fill the prompt/template

# Returns

If `return_all=false` (default):
- `msg`: An `AIMessage` object representing the generated AI message, including the content, status, tokens, and elapsed time.
 Use `msg.content` to access the extracted string.

If `return_all=true`:
- `conversation`: A vector of `AbstractMessage` objects representing the conversation history, including the response from the AI model (`AIMessage`).

See also: `ai_str`, `aai_str`, `aiembed`, `aiclassify`, `aiextract`, `aiscan`, `aitemplates`

# Example

Simple hello world to test the API:
```julia
result = aigenerate("Say Hi!"; model="gemini-pro")
# AIMessage("Hi there! 👋 I'm here to help you with any questions or tasks you may have. Just let me know what you need, and I'll do my best to assist you.")
```

`result` is an `AIMessage` object. Access the generated string via `content` property:
```julia
typeof(result) # AIMessage{SubString{String}}
propertynames(result) # (:content, :status, :tokens, :elapsed
result.content # "Hi there! ...
```
___
You can use string interpolation and alias "gemini":
```julia
a = 1
msg=aigenerate("What is `\$a+\$a`?"; model="gemini")
msg.content # "1+1 is 2."
```
___
You can provide the whole conversation or more intricate prompts as a `Vector{AbstractMessage}`:
```julia
const PT = PromptingTools

conversation = [
    PT.SystemMessage("You're master Yoda from Star Wars trying to help the user become a Yedi."),
    PT.UserMessage("I have feelings for my iPhone. What should I do?")]
msg=aigenerate(conversation; model="gemini")
# AIMessage("Young Padawan, you have stumbled into a dangerous path.... <continues>")
```
"""
function aigenerate(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        verbose::Bool = true,
        api_key::String = GOOGLE_API_KEY,
        model::String = "gemini-pro", return_all::Bool = false, dry_run::Bool = false,
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        http_kwargs::NamedTuple = (retry_non_idempotent = true,
            retries = 5,
            readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(),
        kwargs...)
    ##
    global MODEL_ALIASES
    ## Check that package GoogleGenAI is loaded; it provides the real `ggi_generate_content`
    ## via the extension (the echo schema is exempt because it never calls the API).
    ext = Base.get_extension(PromptingTools, :GoogleGenAIPromptingToolsExt)
    if isnothing(ext) && !(prompt_schema isa TestEchoGoogleSchema)
        throw(ArgumentError("You need to also import GoogleGenAI package to use this function"))
    end
    ## Find the unique ID for the model alias provided
    model_id = get(MODEL_ALIASES, model, model)
    conv_rendered = render(prompt_schema, prompt; conversation, kwargs...)
    if !dry_run
        ## `elapsed_time` instead of `time` to avoid shadowing `Base.time`
        elapsed_time = @elapsed r = ggi_generate_content(prompt_schema, api_key,
            model_id,
            conv_rendered;
            http_kwargs,
            api_kwargs...)
        ## Big overestimate: the API does not report token usage yet, so we count characters
        input_token_estimate = length(JSON3.write(conv_rendered))
        output_token_estimate = length(r.text)
        msg = AIMessage(;
            content = r.text |> strip,
            status = convert(Int, r.response_status),
            ## for google it's CHARACTERS, not tokens
            tokens = (input_token_estimate, output_token_estimate),
            elapsed = elapsed_time)
        ## Reporting
        verbose && @info _report_stats(msg, model_id)
    else
        msg = nothing
    end
    ## Select what to return: just the AIMessage, or the full conversation history
    output = finalize_outputs(prompt,
        conv_rendered,
        msg;
        conversation,
        return_all,
        dry_run,
        kwargs...)
    return output
end
## Embeddings are not implemented for the Google schema; fail loudly with guidance.
function aiembed(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("Google schema does not yet support aiembed. Please use OpenAISchema instead."))
end
## Classification is not implemented for the Google schema; fail loudly with guidance.
function aiclassify(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("Google schema does not yet support aiclassify. Please use OpenAISchema instead."))
end
## Structured extraction is not implemented for the Google schema; fail loudly with guidance.
function aiextract(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("Google schema does not yet support aiextract. Please use OpenAISchema instead."))
end
## Image/document scanning is not implemented for the Google schema; fail loudly with guidance.
function aiscan(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("Google schema does not yet support aiscan. Please use OpenAISchema instead."))
end
## Image generation is not implemented for the Google schema; fail loudly with guidance.
function aiimage(prompt_schema::AbstractGoogleSchema, prompt::ALLOWED_PROMPT_TYPE;
        kwargs...)
    throw(ErrorException("Google schema does not yet support aiimage. Please use OpenAISchema instead."))
end