Skip to content

Commit

Permalink
[FR] Add support for structured extraction with Ollama models
Browse files Browse the repository at this point in the history
Fixes #68
  • Loading branch information
cpfiffer committed May 5, 2024
1 parent 1cda053 commit 97c9733
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 3 deletions.
132 changes: 131 additions & 1 deletion src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@
# There are potential formats: 1) JSON-based for OpenAI compatible APIs, 2) XML-based for Anthropic compatible APIs (used also by Hermes-2-Pro model).
#

"""
JSON_PRIMITIVE_TYPES
A set of primitive types that are supported by JSON. If a type
is not in this set, the JSON typer [`to_json_type`](@ref) will
assume that the type is a `struct` and will attempt to recursively
unpack the fields of the struct.
"""
const JSON_PRIMITIVE_TYPES = Union{
Integer,
Real,
AbstractString,
Bool,
Nothing,
Missing,
AbstractArray
}

######################
# 1) OpenAI / JSON format
######################
Expand All @@ -15,7 +33,14 @@ to_json_type(n::Type{<:Real}) = "number"
to_json_type(n::Type{<:Integer}) = "integer"
to_json_type(b::Type{Bool}) = "boolean"
to_json_type(t::Type{<:Union{Missing, Nothing}}) = "null"
to_json_type(t::Type{<:Any}) = "string" # object?
to_json_type(t::Type{T}) where {T <: AbstractArray} = to_json_type(eltype(t)) * "[]"
to_json_type(t::Type{Any}) = throw(ArgumentError("""
Type $t is not a valid type for to_json_type. Please provide a valid type found in:
$JSON_PRIMITIVE_TYPES
You may be using to_json_schema but forgot to properly type the fields of your struct.
"""))

has_null_type(T::Type{Missing}) = true
has_null_type(T::Type{Nothing}) = true
Expand Down Expand Up @@ -236,3 +261,108 @@ Extract zero, one or more specified items from the provided data.
struct ItemsExtract{T <: Any}
items::Vector{T}
end

"""
typed_json_schema(x::Type{T}) where {T}
Convert a Julia type to a JSON schema that lists keys as field names and values as
the types of those field names.
WARNING! Every field in your struct, and all nested structs, must be typed using a subtype of values in [`JSON_PRIMITIVE_TYPES`](@ref)
before calling this function. Otherwise, you will get a recursion error.
## Example
```julia
# Simple flat structure where each field is a primitive type
struct SimpleSingleton
singleton_value::Int
end
typed_json_schema(SimpleSingleton)
```
```
Dict{Any, Any} with 1 entry:
:singleton_value => "integer"
```
Or using nested structs
```julia
# Test a struct that contains another struct.
struct Nested
inside_element::SimpleSingleton
end
typed_json_schema(Nested)
```
```julia
Dict{Any, Any} with 1 entry:
:inside_element => Dict{Any, Any}("singleton_value" => "integer")
```
Lists of created Julia types will be specified as `List[Object]` with the value being the type of the elements,
i.e.
```julia
# Test a struct with a vector of primitives
struct ABunchOfVectors
strings::Vector{String}
ints::Vector{Int}
floats::Vector{Float64}
nested_vector::Vector{Nested}
end
typed_json_schema(ABunchOfVectors)
```
```
Dict{Any, Any} with 4 entries:
:strings => "string[]"
:ints => "integer[]"
:nested_vector => Dict("list[Object]"=>"{\"inside_element\":{\"singleton_value\":\"integer\"}}")
:floats => "number[]"
```
## Resources
- the [original issue](https://github.com/svilupp/PromptingTools.jl/issues/143)
- the [motivation](https://www.boundaryml.com/blog/type-definition-prompting-baml)
"""
function typed_json_schema(x::Type{T}) where {T}
# We can return early if the type is a non-array primitive
if T <: JSON_PRIMITIVE_TYPES && !(T <: AbstractArray)
return to_json_type(T)
end

# If there are no fields, return the type
if isempty(fieldnames(T))
# Check if this is a vector type. If so, return the type of the elements.
if T <: AbstractArray
# Now check if the element type is a non-primitive. If so, recursively call typed_json_schema.
if eltype(T) <: JSON_PRIMITIVE_TYPES
return to_json_type(T)
else
return Dict("list[Object]" => JSON3.write(typed_json_schema(eltype(T))))
# return "List[" * JSON3.write(typed_json_schema(eltype(T))) * "]"
end
end

# Check if the type is a non-primitive.
if T <: JSON_PRIMITIVE_TYPES
return to_json_type(T)
else
return typed_json_schema(T)
end
end

# Preallocate a mapping
mapping = Dict()
for (type, field) in zip(T.types, fieldnames(T))
mapping[field] = typed_json_schema(type)
end

# Get property names
return mapping
end
126 changes: 124 additions & 2 deletions test/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,11 @@ end
end
output = function_call_signature(MyMeasurement2)#|> JSON3.pretty
expected_output = Dict{String, Any}("name" => "MyMeasurement2_extractor",
"parameters" => Dict{String, Any}("properties" => Dict{String, Any}("height" => Dict{
"parameters" => Dict{String, Any}(
"properties" => Dict{String, Any}(
"height" => Dict{
String,
Any,
Any
}("type" => "integer"),
"weight" => Dict{String, Any}("type" => "number"),
"age" => Dict{String, Any}("type" => "integer")),
Expand All @@ -240,3 +242,123 @@ end
schema = function_call_signature(MaybeExtract{MyMeasurement2})
@test schema["name"] == "MaybeExtractMyMeasurement2_extractor"
end
@testset "to_json_schema-primitive_types" begin
@test to_json_schema(Int) == Dict("type" => "integer")
@test to_json_schema(Float64) == Dict("type" => "number")
@test to_json_schema(Bool) == Dict("type" => "boolean")
@test to_json_schema(String) == Dict("type" => "string")
@test_throws ArgumentError to_json_schema(Any) # Type Any is not supported
end
@testset "to_json_schema-structs" begin
# Function to check the equivalence of two JSON strings, since Dict is
# unordered, we need to sort keys before comparison.
function check_json_equivalence(json1::AbstractString, json2::AbstractString)
println("\ncheck_json_equivalence\n===json1===")
println(json1)
println("===json2===")
println(json2)
println()
# JSON dictionary
d1 = JSON3.read(json1)
d2 = JSON3.read(json2)

# Get all the keys
k1 = sort(collect(keys(d1)))
k2 = sort(collect(keys(d2)))

# Test that all the keys are present
@test setdiff(k1, k2) == []
@test setdiff(k2, k1) == []

# Test that all the values are equivalent
for (k, v) in d1
@test d2[k] == v
end

# @test JSON3.write(JSON3.read(json1)) == JSON3.write(JSON3.read(json2))
end
function check_json_equivalence(d::Dict, s::AbstractString)
return check_json_equivalence(JSON3.write(d), s)
end

# Simple flat structure where each field is a primitive type
struct SimpleSingleton
singleton_value::Int
end

check_json_equivalence(
JSON3.write(typed_json_schema(SimpleSingleton)),
"{\"singleton_value\":\"integer\"}"
)

# Test a struct that contains another struct.
struct Nested
inside_element::SimpleSingleton
end

check_json_equivalence(
JSON3.write(typed_json_schema(Nested)),
"{\"inside_element\":{\"singleton_value\":\"integer\"}}"
)

# Test a struct with two primitive types
struct IntFloatFlat
int_value::Int
float_value::Float64
end
check_json_equivalence(
typed_json_schema(IntFloatFlat),
"{\"int_value\":\"integer\",\"float_value\":\"number\"}"
)

# Test a struct that contains all primitive types
struct AllJSONPrimitives
int::Integer
float::Real
string::AbstractString
bool::Bool
nothing::Nothing
missing::Missing

# Array types
array_of_strings::Vector{String}
array_of_ints::Vector{Int}
array_of_floats::Vector{Float64}
array_of_bools::Vector{Bool}
array_of_nothings::Vector{Nothing}
array_of_missings::Vector{Missing}
end

check_json_equivalence(
typed_json_schema(AllJSONPrimitives),
"{\"int\":\"integer\",\"float\":\"number\",\"string\":\"string\",\"bool\":\"boolean\",\"nothing\":\"null\",\"missing\":\"null\",\"array_of_strings\":\"string[]\",\"array_of_ints\":\"integer[]\",\"array_of_floats\":\"number[]\",\"array_of_bools\":\"boolean[]\",\"array_of_nothings\":\"null[]\",\"array_of_missings\":\"null[]\"}"
)

# Test a struct with a vector of primitives
struct ABunchOfVectors
strings::Vector{String}
ints::Vector{Int}
floats::Vector{Float64}
nested_vector::Vector{Nested}
end

check_json_equivalence(
typed_json_schema(ABunchOfVectors),
"{\"strings\":\"string[]\",\"ints\":\"integer[]\",\"nested_vector\":{\"list[Object]\":\"{\\\"inside_element\\\":{\\\"singleton_value\\\":\\\"integer\\\"}}\"},\"floats\":\"number[]\"}"
)

# Weird struct with a bunch of different types
struct Monster
name::String
age::Int
height::Float64
friends::Vector{String}
nested::Nested
flat::IntFloatFlat
end

check_json_equivalence(
typed_json_schema(Monster),
"{\"flat\":{\"float_value\":\"number\",\"int_value\":\"integer\"},\"nested\":{\"inside_element\":{\"singleton_value\":\"integer\"}},\"age\":\"integer\",\"name\":\"string\",\"height\":\"number\",\"friends\":\"string[]\"}"
)
end;

0 comments on commit 97c9733

Please sign in to comment.