diff --git a/chat/tools/tools.go b/chat/tools/tools.go index 62fdb92..a2fba1c 100644 --- a/chat/tools/tools.go +++ b/chat/tools/tools.go @@ -3,19 +3,25 @@ package tools import ( "encoding/json" "fmt" + "reflect" ) func NewTool[T any](name, description string, f func(T) (string, error)) (tool, error) { inputType := ensureInputStructType[T]() - wrapper := func(input string) (string, error) { - raw := []byte(input) + var extract func(reflect.Value) T + if inputType == reflect.TypeFor[T]() { + extract = func(v reflect.Value) T { return v.Interface().(T) } + } else { + extract = func(v reflect.Value) T { return v.Field(0).Interface().(T) } + } - var parsed T - if err := json.Unmarshal(raw, &parsed); err != nil { - return "", fmt.Errorf("unmarshal into %T: %w", parsed, err) + wrapper := func(input string) (string, error) { + ptr := reflect.New(inputType) + if err := json.Unmarshal([]byte(input), ptr.Interface()); err != nil { + return "", fmt.Errorf("unmarshal into %v: %w", inputType, err) } - return f(parsed) + return f(extract(ptr.Elem())) } t := tool{ diff --git a/chat/tools/tools_test.go b/chat/tools/tools_test.go index ce4cb06..0cdea5b 100644 --- a/chat/tools/tools_test.go +++ b/chat/tools/tools_test.go @@ -53,6 +53,9 @@ type RecursiveStruct struct { func TestNewTool_WithStruct(t *testing.T) { tool, err := NewTool("test_struct", "test description", func(s SimpleStruct) (string, error) { + if s.Name != "John" || s.Age != 30 { + return "", errors.New("invalid input") + } return "ok", nil }) @@ -72,6 +75,23 @@ func TestNewTool_WithStruct(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_struct", `{"name": "John", "age": 30}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithPrimitive(t *testing.T) { @@ -102,15 +122,31 @@ func TestNewTool_WithPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "input") { t.Errorf("NewTool() schema should contain 'input' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_primitive", `{"input": "hello"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_NestedStruct(t *testing.T) { tool, err := NewTool("test_nested", "nested struct test", func(n NestedStruct) (string, error) { + if n.User.Name != "John" || n.User.Age != 30 || n.Active != true { + return "", errors.New("invalid input") + } return "ok", nil }) if err != nil { - t.Errorf("NewTool() error = %v", err) return } @@ -128,10 +164,28 @@ func TestNewTool_NestedStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "user") || !strings.Contains(string(schemaJSON), "active") { t.Errorf("NewTool() schema should contain 'user' and 'active' fields, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_nested", `{"user": {"name": "John", "age": 30}, "active": true}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_EmbeddedStruct(t *testing.T) { tool, err := NewTool("test_embedded", "embedded struct test", func(e EmbeddedStruct) (string, error) { + if e.Name != "John" || e.Age != 30 { + return "", errors.New("invalid input") + } return "ok", nil }) @@ -154,6 +208,21 @@ func TestNewTool_EmbeddedStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "name") || !strings.Contains(string(schemaJSON), "age") { t.Errorf("NewTool() schema should contain embedded fields 'name' and 'age', got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_embedded", `{"name": "John", "age": 30}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { @@ -162,6 +231,11 @@ func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { } tool, err := NewTool("test_map_primitive", "map primitive test", func(m MapPrimitive) (string, error) { + scores := m.Scores + if scores["a"] != 1 || scores["b"] != 2 { + t.Errorf("invalid scores: %v", m) + return "", errors.New("invalid scores") + } return "ok", nil }) @@ -183,10 +257,29 @@ func TestNewTool_PrimitiveWithMapPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "scores") { t.Errorf("NewTool() schema should contain 'scores' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_map_primitive", `{"scores": {"a": 1, "b": 2}}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMapStruct(t *testing.T) { tool, err := NewTool("test_map_struct", "map struct test", func(m MapStruct) (string, error) { + users := m.Users + if users["alice"].Age != 25 || users["bob"].Age != 30 { + return "", errors.New("invalid users") + } return "ok", nil }) @@ -208,10 +301,29 @@ func TestNewTool_PrimitiveWithMapStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "users") { t.Errorf("NewTool() schema should contain 'users' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_map_struct", `{"users": {"alice": {"age": 25}, "bob": {"age": 30}}}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithSlicePrimitive(t *testing.T) { tool, err := NewTool("test_slice_primitive", "slice primitive test", func(s SlicePrimitive) (string, error) { + ids := s.IDs + if len(ids) != 3 || ids[0] != 1 || ids[1] != 2 || ids[2] != 3 { + return "", errors.New("invalid ids") + } return "ok", nil }) @@ -233,10 +345,29 @@ func TestNewTool_PrimitiveWithSlicePrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "ids") { t.Errorf("NewTool() schema should contain 'ids' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + // Try to execute with valid input + result, ok := tools.Execute("test_slice_primitive", `{"ids": [1, 2, 3]}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithSliceStruct(t *testing.T) { tool, err := NewTool("test_slice_struct", "slice struct test", func(s SliceStruct) (string, error) { + items := s.Items + if len(items) != 2 || items[0].Name != "Alice" || items[0].Age != 25 || items[1].Name != "Bob" || items[1].Age != 30 { + return "", errors.New("invalid items") + } return "ok", nil }) @@ -258,6 +389,20 @@ func TestNewTool_PrimitiveWithSliceStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "items") { t.Errorf("NewTool() schema should contain 'items' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_slice_struct", `{"items": [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}]}`) + t.Log("result: ", result) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } } func TestNewTool_PrimitiveWithMap(t *testing.T) { @@ -267,6 +412,10 @@ func TestNewTool_PrimitiveWithMap(t *testing.T) { } tool, err := NewTool("test_map", "map test", func(m MapInput) (string, error) { + val, ok := m.Data["key"] + if !ok || val != "value" { + return "", errors.New("invalid data") + } return "ok", nil }) @@ -278,10 +427,33 @@ func TestNewTool_PrimitiveWithMap(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_map", `{"data": {"key": "value"}}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_RecursiveStruct(t *testing.T) { tool, err := NewTool("test_recursive", "recursive struct test", func(r RecursiveStruct) (string, error) { + if r.Value != 42 { + return "", errors.New("invalid value") + } + if r.Child == nil || r.Child.Value != 100 { + return "", errors.New("invalid child") + } return "ok", nil }) @@ -303,6 +475,23 @@ func TestNewTool_RecursiveStruct(t *testing.T) { if !strings.Contains(string(schemaJSON), "value") || !strings.Contains(string(schemaJSON), "child") { t.Errorf("NewTool() schema should contain 'value' and 'child' fields, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("test_recursive", `{"value": 42, "child": {"value": 100}}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } // ============================================================================ @@ -563,6 +752,9 @@ func TestNewTool_PointerType(t *testing.T) { } tool, err := NewTool("pointer_test", "test pointer type", func(s PointerStruct) (string, error) { + if s.Name == nil || *s.Name != "John" { + return "", errors.New("invalid name") + } return "ok", nil }) @@ -574,6 +766,23 @@ func TestNewTool_PointerType(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for pointer type") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("pointer_test", `{"name": "John"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_InterfaceType(t *testing.T) { @@ -582,6 +791,9 @@ func TestNewTool_InterfaceType(t *testing.T) { } tool, err := NewTool("interface_test", "test interface type", func(s InterfaceStruct) (string, error) { + if s.Data == nil { + return "", errors.New("data is nil") + } return "ok", nil }) @@ -593,6 +805,23 @@ func TestNewTool_InterfaceType(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for interface type") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("interface_test", `{"data": "test"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_AnonymousStruct(t *testing.T) { @@ -600,6 +829,9 @@ func TestNewTool_AnonymousStruct(t *testing.T) { tool, err := NewTool("anon_struct", "test anonymous struct", func(s struct { Name string `json:"name"` }) (string, error) { + if s.Name != "John" { + return "", errors.New("invalid name") + } return "ok", nil }) @@ -611,11 +843,31 @@ func TestNewTool_AnonymousStruct(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for anonymous struct") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("anon_struct", `{"name": "John"}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithIntPrimitive(t *testing.T) { tool, err := NewTool("int_primitive", "test int primitive", func(n int) (string, error) { - return "received", nil + if n != 42 { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -632,11 +884,31 @@ func TestNewTool_WithIntPrimitive(t *testing.T) { if !strings.Contains(string(schemaJSON), "input") { t.Errorf("NewTool() schema should contain 'input' field, got: %s", schemaJSON) } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("int_primitive", `{"input": 42}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } func TestNewTool_WithBoolPrimitive(t *testing.T) { tool, err := NewTool("bool_primitive", "test bool primitive", func(b bool) (string, error) { - return "received", nil + if b != true { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -647,6 +919,23 @@ func TestNewTool_WithBoolPrimitive(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for bool primitive") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("bool_primitive", `{"input": true}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } } // ============================================================================ @@ -729,7 +1018,10 @@ func containsHelper(s, substr string) bool { func TestNewTool_WithFloatPrimitive(t *testing.T) { tool, err := NewTool("float_primitive", "test float64 primitive", func(f float64) (string, error) { - return "received", nil + if f != 3.14 { + return "", errors.New("invalid value") + } + return "ok", nil }) if err != nil { @@ -740,4 +1032,21 @@ func TestNewTool_WithFloatPrimitive(t *testing.T) { if tool.schema == nil { t.Error("NewTool() schema is nil for float64 primitive") } + + tools := NewTools() + err = tools.Add(tool) + if err != nil { + t.Errorf("Add() error = %v", err) + return + } + + result, ok := tools.Execute("float_primitive", `{"input": 3.14}`) + if !ok { + t.Errorf("Execute() error = %v", result) + return + } + + if result != "ok" { + t.Errorf("Execute() result = %v, want %v", result, "ok") + } }