-
Notifications
You must be signed in to change notification settings - Fork 8
/
llm_shared.jl
134 lines (123 loc) · 6.41 KB
/
llm_shared.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Reusable functionality across different schemas
# Generic fallback for `role4render`: reached only when no more specific method
# matches the (schema, message) pair. It always throws an `ArgumentError` that
# names the concrete types involved.
function role4render(schema::AbstractPromptSchema, msg::AbstractMessage)
    schema_type = typeof(schema)
    msg_type = typeof(msg)
    err = ArgumentError("Function `role4render` is not implemented for the provided schema ($(schema_type)) and $(msg_type).")
    throw(err)
end
# Default role labels for any `AbstractPromptSchema`, one method per message type.
# System messages render with the "system" role.
function role4render(schema::AbstractPromptSchema, msg::SystemMessage)
    return "system"
end
# Plain user messages render with the "user" role.
function role4render(schema::AbstractPromptSchema, msg::UserMessage)
    return "user"
end
# User messages carrying images use the same "user" role.
function role4render(schema::AbstractPromptSchema, msg::UserMessageWithImages)
    return "user"
end
# Model replies render with the "assistant" role.
function role4render(schema::AbstractPromptSchema, msg::AIMessage)
    return "assistant"
end
"""
    render(schema::NoSchema,
        messages::Vector{<:AbstractMessage};
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        replacement_kwargs...)

Renders a conversation history from a vector of messages with all replacement variables specified in `replacement_kwargs`.

It is the first pass of the prompt rendering system, and is used by all other schemas.

# Keyword Arguments
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector. The input is never mutated; it may have any element type that is a subtype of `AbstractMessage`.
- `replacement_kwargs...`: All remaining keyword arguments. Each `key = value` replaces `{{key}}` with `value` in the content of a message, but only if `key` is listed in that message's `variables`.
- `image_detail`: Not interpreted by this method. If provided, it is collected into `replacement_kwargs` like any other keyword.

# Returns
- A new `Vector{AbstractMessage}` holding the conversation history followed by the rendered `messages`, with the system message in the first position.

# Throws
- `ArgumentError` if more than one system message is found across `conversation` and `messages`.

# Notes
- All unspecified kwargs are passed as replacements such that `{{key}}=>value` in the template.
- Replacements are applied only to `SystemMessage`, `UserMessage` and `UserMessageWithImages`. An `AIMessage` is appended unchanged, and a tracer message wrapping a system message is moved to the front unchanged.
- Any other message type is skipped with a warning.
- If a SystemMessage is missing, we inject a default one at the beginning of the conversation.
- Only one SystemMessage is allowed (ie, cannot mix two conversations with different system prompts).
"""
function render(schema::NoSchema,
        messages::Vector{<:AbstractMessage};
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        replacement_kwargs...)
    ## Copy the conversation to avoid mutating the original.
    ## Collect into a `Vector{AbstractMessage}` (rather than `copy`) so that a narrowly
    ## typed history, eg, `[UserMessage("hi")]`, can still receive messages of other types
    ## (a plain copy keeps the narrow eltype and `push!`/`pushfirst!` would fail to convert).
    conversation = collect(AbstractMessage, conversation)
    count_system_msg = count(issystemmessage, conversation)
    # TODO: concat multiple system messages together (2nd pass)

    # replace any handlebar variables in the messages
    for msg in messages
        if msg isa Union{SystemMessage, UserMessage, UserMessageWithImages}
            # Only substitute placeholders that the message declares in `variables`
            replacements = ["{{$(key)}}" => value
                            for (key, value) in pairs(replacement_kwargs)
                            if key in msg.variables]
            # Rebuild the message with the replaced content
            MSGTYPE = typeof(msg)
            new_msg = MSGTYPE(;
                # unpack the type to replace only the content field
                # (the trailing `content` keyword overrides the unpacked one)
                [(field, getfield(msg, field)) for field in fieldnames(typeof(msg))]...,
                content = replace(msg.content, replacements...))
            if msg isa SystemMessage
                count_system_msg += 1
                # move to the front
                pushfirst!(conversation, new_msg)
            else
                push!(conversation, new_msg)
            end
        elseif msg isa AIMessage
            # no replacements
            push!(conversation, msg)
        elseif istracermessage(msg) && issystemmessage(msg.object)
            # Look for tracers wrapping a system message; they count as the system message
            count_system_msg += 1
            # move to the front
            pushfirst!(conversation, msg)
        else
            # Note: Ignores any DataMessage or other types for the prompt/conversation history
            @warn "Unexpected message type: $(typeof(msg)). Skipping."
        end
    end

    ## Multiple system prompts are not allowed
    (count_system_msg > 1) && throw(ArgumentError("Only one system message is allowed."))
    ## Add default system prompt if not provided
    (count_system_msg == 0) && pushfirst!(conversation,
        SystemMessage("Act as a helpful AI assistant"))

    return conversation
end
"""
    finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any,
        msg::Union{Nothing, AbstractMessage, AbstractVector{<:AbstractMessage}};
        return_all::Bool = false,
        dry_run::Bool = false,
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        kwargs...)

Decides what the ai* functions hand back to the caller: only the response `msg`, or the whole conversation.

# Keyword arguments
- `return_all::Bool=false`: If false, `msg` is returned as is and every other keyword is ignored. If true, a conversation is returned instead (see `dry_run`).
- `dry_run::Bool=false`: Only consulted when `return_all=true`. If true, `conv_rendered` is returned untouched (useful for inspecting the schema-specific rendering).
  If false, the conversation is re-rendered with `NoSchema` and `msg` is added at the end (all of its elements when `msg` is a vector, eg, for multi-sample responses).
- `conversation::AbstractVector{<:AbstractMessage}=[]`: An optional conversation history, forwarded to `render` when the conversation is re-rendered.
- `kwargs...`: Variables to replace in the prompt template, forwarded to `render`.
"""
function finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any,
        msg::Union{Nothing, AbstractMessage, AbstractVector{<:AbstractMessage}};
        return_all::Bool = false,
        dry_run::Bool = false,
        conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
        kwargs...)
    ## Default path: the caller wants only the response
    return_all || return msg
    ## Dry run: hand back the schema-rendered conversation untouched
    dry_run && return conv_rendered
    ## Otherwise rebuild the pre-schema messages. This repeats work already done for
    ## `conv_rendered`, but it gives the caller plain messages instead of the
    ## schema-specific rendering (user experience is prioritized over performance)
    history = render(NoSchema(), prompt; conversation, kwargs...)
    ## A vector `msg` (multi-sample) is appended element by element, a single one is pushed
    msg isa AbstractVector ? append!(history, msg) : push!(history, msg)
    return history
end
## Helpers for aiclassify -> they encode the choice list to create the prompt and then extract the original choice category
# Generic fallback for `encode_choices`: always throws an `ArgumentError` that
# reports the schema and the choices it was called with. Extra keyword
# arguments are accepted and ignored.
function encode_choices(schema::AbstractPromptSchema, choices; kwargs...)
    errmsg = "Function `encode_choices` is not implemented for the provided schema ($schema) and $(choices)."
    throw(ArgumentError(errmsg))
end
# Generic fallback for `decode_choices`: always throws an `ArgumentError` that
# reports the schema and the choices it was called with (`conv` is not part of
# the message). Extra keyword arguments are accepted and ignored.
function decode_choices(schema::AbstractPromptSchema, choices, conv; kwargs...)
    errmsg = "Function `decode_choices` is not implemented for the provided schema ($schema) and $(choices)."
    throw(ArgumentError(errmsg))
end
# With no conversation (`conv === nothing`) there is nothing to decode, so
# `nothing` is returned for any schema and any choices.
function decode_choices(schema::AbstractPromptSchema, choices, conv::Nothing; kwargs...)
    return nothing
end