Skip to content

Commit

Permalink
Merge pull request #28 from Naist4869/patch-1
Browse files Browse the repository at this point in the history
Update functioncall.go
  • Loading branch information
otiai10 committed Sep 11, 2023
2 parents db9573c + 187a75b commit 8661b0c
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 6 deletions.
111 changes: 105 additions & 6 deletions functioncall/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package functioncall

import (
"encoding/json"
"reflect"
"testing"

. "github.com/otiai10/mint"
Expand All @@ -20,9 +21,9 @@ func TestFunctions_MarshalJSON(t *testing.T) {
return r
}
funcs := Funcs{
"repeat": Func{repeat, "Repeat given string N times", Params{
{"word", "string", "String to be repeated", true},
{"count", "number", "How many times to repeat", true},
"repeat": Func{Value: repeat, Description: "Repeat given string N times", Parameters: Params{
{Name: "word", Type: "string", Description: "String to be repeated", Required: true},
{Name: "count", Type: "number", Description: "How many times to repeat", Required: true},
}},
}
b, err := funcs.MarshalJSON()
Expand All @@ -47,13 +48,111 @@ func TestAs(t *testing.T) {
return r
}
funcs := Funcs{
"repeat": Func{repeat, "Repeat given string N times", Params{
{"word", "string", "String to be repeated", true},
{"count", "number", "How many times to repeat", true},
"repeat": Func{Value: repeat, Description: "Repeat given string N times", Parameters: Params{
{Name: "word", Type: "string", Description: "String to be repeated", Required: true},
{Name: "count", Type: "number", Description: "How many times to repeat", Required: true},
}},
}
a := As[[]map[string]any](funcs)
Expect(t, a).TypeOf("[]map[string]interface {}")
Expect(t, a).Query("0.name").ToBe("repeat")
Expect(t, a).Query("0.parameters.type").ToBe("object")
}

func TestParams_MarshalJSON(t *testing.T) {
tests := []struct {
name string
params Params
want []byte
wantErr bool
}{
{
name: "nested",
params: []Param{
{
Name: "quality",
Type: "object",
Description: "",
Required: true,
Items: []Param{
{
Name: "pros",
Type: "array",
Description: "Write 3 points why this text is well written",
Required: true,
Items: []Param{
{Type: "string"},
},
},
},
},
},
want: []byte(`{"properties":{"quality":{"properties":{"pros":{"description":"Write 3 points why this text is well written","items":{"type":"string"},"type":"array"}},"required":["pros"],"type":"object"}},"required":["quality"],"type":"object"}`),
wantErr: false,
},
{
name: "nested_example",
params: []Param{
{
Name: "ingredients",
Type: "array",
Required: true,
Items: []Param{
{
Type: "object",
Items: []Param{
{
Name: "name",
Type: "string",
Required: true,
},
{
Name: "unit",
Type: "string",
// Enum: []any{"grams", "ml", "cups", "pieces", "teaspoons"},
Required: true,
},
{
Name: "amount",
Type: "number",
Required: true,
},
},
},
},
},
{
Name: "instructions",
Type: "array",
Required: true,
Items: []Param{
{
Type: "string",
},
},
Description: "Steps to prepare the recipe (no numbering)",
},
{
Name: "time_to_cook",
Type: "number",
Description: "Total time to prepare the recipe in minutes",
Required: true,
},
},
want: []byte(`{"properties":{"ingredients":{"items":{"properties":{"amount":{"type":"number"},"name":{"type":"string"},"unit":{"type":"string"}},"required":["name","unit","amount"],"type":"object"},"type":"array"},"instructions":{"description":"Steps to prepare the recipe (no numbering)","items":{"type":"string"},"type":"array"},"time_to_cook":{"type":"number","description":"Total time to prepare the recipe in minutes"}},"required":["ingredients","instructions","time_to_cook"],"type":"object"}`),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.params)
if (err != nil) != tt.wantErr {
t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("MarshalJSON() got = %s, want %s", got, tt.want)
}
})
}
}
35 changes: 35 additions & 0 deletions functioncall/functioncall.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ type Func struct {

type Params []Param

type NestedParams []Param

type Param struct {
Name string `json:"-"`
Type string `json:"type,omitempty"`
Description string `json:"description,omitempty"`
Required bool `json:"-"`
// Enum []any `json:"enum,omitempty"`
Items NestedParams `json:",omitempty"`
}

func (funcs Funcs) MarshalJSON() ([]byte, error) {
Expand All @@ -37,6 +40,18 @@ func (funcs Funcs) MarshalJSON() ([]byte, error) {
}

func (params Params) MarshalJSON() ([]byte, error) {
return marshalObject(params)
}

func (params NestedParams) MarshalJSON() ([]byte, error) {
if len(params) == 1 {
return json.Marshal(params[0])
}

return marshalObject(params)
}

func marshalObject[T ~[]Param](params T) ([]byte, error) {
required := []string{}
props := map[string]Param{}
for _, p := range params {
Expand All @@ -45,6 +60,7 @@ func (params Params) MarshalJSON() ([]byte, error) {
}
props[p.Name] = p
}

schema := map[string]any{
"type": "object",
"properties": props,
Expand All @@ -53,6 +69,25 @@ func (params Params) MarshalJSON() ([]byte, error) {
return json.Marshal(schema)
}

func (param Param) MarshalJSON() ([]byte, error) {
switch param.Type {
case "array":
schema := map[string]any{
"type": "array",
"items": param.Items,
}
if param.Description != "" {
schema["description"] = param.Description
}
return json.Marshal(schema)
case "object":
return marshalObject(param.Items)
default:
type Alias Param
return json.Marshal(Alias(param))
}
}

func As[T any](funcs Funcs) (dest T) {
b, err := funcs.MarshalJSON()
if err != nil {
Expand Down

0 comments on commit 8661b0c

Please sign in to comment.