diff --git a/backend/pkg/providers/bedrock/bedrock.go b/backend/pkg/providers/bedrock/bedrock.go index 6e569617..9e0f17a7 100644 --- a/backend/pkg/providers/bedrock/bedrock.go +++ b/backend/pkg/providers/bedrock/bedrock.go @@ -168,11 +168,24 @@ func (p *bedrockProvider) CallEx( chain []llms.MessageContent, streamCb streaming.Callback, ) (*llms.ContentResponse, error) { + options := []llms.CallOption{ + llms.WithStreamingFunc(streamCb), + } + + // The AWS Bedrock Converse API requires toolConfig to be defined whenever the + // conversation history contains toolUse or toolResult content blocks — even when + // no new tools are being offered in the current turn. Without it the API returns: + // ValidationException: The toolConfig field must be defined when using + // toolUse and toolResult content blocks. + // We reconstruct minimal tool definitions from the tool-call names already + // present in the chain so that the library sets toolConfig automatically. + if minimalTools := buildMinimalToolsFromChain(chain); len(minimalTools) > 0 { + options = append(options, llms.WithTools(minimalTools)) + } + return provider.WrapGenerateContent( ctx, p, opt, p.llm.GenerateContent, chain, - append([]llms.CallOption{ - llms.WithStreamingFunc(streamCb), - }, p.providerConfig.GetOptionsForType(opt)...)..., + append(options, p.providerConfig.GetOptionsForType(opt)...)..., ) } @@ -183,6 +196,14 @@ func (p *bedrockProvider) CallWithTools( tools []llms.Tool, streamCb streaming.Callback, ) (*llms.ContentResponse, error) { + // Same Bedrock Converse API requirement as in CallEx: if no tools were + // explicitly provided for this turn but the chain already carries toolUse / + // toolResult blocks, reconstruct minimal definitions so that the library + // includes toolConfig in the request. + if len(tools) == 0 { + tools = buildMinimalToolsFromChain(chain) + } + return provider.WrapGenerateContent( ctx, p, opt, p.llm.GenerateContent, chain, append([]llms.CallOption{ @@ -199,3 +220,53 @@ func (p *bedrockProvider) GetUsage(info map[string]any) pconfig.CallUsage { func (p *bedrockProvider) GetToolCallIDTemplate(ctx context.Context, prompter templates.Prompter) (string, error) { return provider.DetermineToolCallIDTemplate(ctx, p, pconfig.OptionsTypeSimple, prompter) } + +// buildMinimalToolsFromChain inspects a conversation chain for ToolCall and +// ToolCallResponse parts and returns minimal llms.Tool definitions for every +// unique tool name found. This is required by the AWS Bedrock Converse API: +// whenever the request messages contain toolUse or toolResult content blocks, +// the request MUST also include a valid toolConfig — otherwise the API returns +// +// ValidationException: The toolConfig field must be defined when using +// toolUse and toolResult content blocks. +// +// We intentionally build placeholder definitions (no real parameter schema) +// because Bedrock only validates that the config is present, not that the +// schemas match historical usage. +func buildMinimalToolsFromChain(chain []llms.MessageContent) []llms.Tool { + seen := make(map[string]struct{}) + for _, msg := range chain { + for _, part := range msg.Parts { + switch p := part.(type) { + case llms.ToolCall: + if p.FunctionCall != nil && p.FunctionCall.Name != "" { + seen[p.FunctionCall.Name] = struct{}{} + } + case llms.ToolCallResponse: + if p.Name != "" { + seen[p.Name] = struct{}{} + } + } + } + } + + if len(seen) == 0 { + return nil + } + + tools := make([]llms.Tool, 0, len(seen)) + for name := range seen { + tools = append(tools, llms.Tool{ + Type: "function", + Function: &llms.FunctionDefinition{ + Name: name, + Description: fmt.Sprintf("Tool: %s", name), + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }, + }) + } + return tools +} diff --git a/backend/pkg/providers/bedrock/bedrock_test.go b/backend/pkg/providers/bedrock/bedrock_test.go index c77a44c8..6c8db015 100644 --- a/backend/pkg/providers/bedrock/bedrock_test.go +++ b/backend/pkg/providers/bedrock/bedrock_test.go @@ -1,12 +1,15 @@ package bedrock import ( + "fmt" + "sort" "testing" "pentagi/pkg/config" "pentagi/pkg/providers/pconfig" "pentagi/pkg/providers/provider" + "github.com/vxcontrol/langchaingo/llms" "github.com/vxcontrol/langchaingo/llms/bedrock" ) @@ -181,3 +184,252 @@ func TestGetUsage(t *testing.T) { t.Errorf("Expected zero tokens with empty usage info, got %s", usage.String()) } } + +// TestBuildMinimalToolsFromChain verifies the helper that reconstructs minimal +// tool definitions from a conversation chain. This is the foundation of the +// fix for the AWS Bedrock Converse API ValidationException that occurs when +// messages contain toolUse / toolResult blocks but no toolConfig is provided. +func TestBuildMinimalToolsFromChain(t *testing.T) { + t.Run("empty chain returns nil", func(t *testing.T) { + result := buildMinimalToolsFromChain(nil) + if result != nil { + t.Errorf("expected nil for empty chain, got %v", result) + } + + result = buildMinimalToolsFromChain([]llms.MessageContent{}) + if result != nil { + t.Errorf("expected nil for empty chain slice, got %v", result) + } + }) + + t.Run("chain with only text messages returns nil", func(t *testing.T) { + chain := []llms.MessageContent{ + llms.TextParts(llms.ChatMessageTypeSystem, "You are helpful."), + llms.TextParts(llms.ChatMessageTypeHuman, "What is 2+2?"), + llms.TextParts(llms.ChatMessageTypeAI, "4"), + } + result := buildMinimalToolsFromChain(chain) + if result != nil { + t.Errorf("expected nil for text-only chain, got %v", result) + } + }) + + t.Run("chain with ToolCall returns tool definition", func(t *testing.T) { + chain := []llms.MessageContent{ + llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather?"), + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: "call_abc", + Type: "function", + FunctionCall: &llms.FunctionCall{ + Name: "get_weather", + Arguments: `{"location":"NYC"}`, + }, + }, + }, + }, + } + + result := buildMinimalToolsFromChain(chain) + if len(result) != 1 { + t.Fatalf("expected 1 tool definition, got %d", len(result)) + } + if result[0].Function == nil { + t.Fatal("tool Function must not be nil") + } + if result[0].Function.Name != "get_weather" { + t.Errorf("expected tool name 'get_weather', got '%s'", result[0].Function.Name) + } + if result[0].Type != "function" { + t.Errorf("expected type 'function', got '%s'", result[0].Type) + } + if result[0].Function.Parameters == nil { + t.Error("expected non-nil Parameters (placeholder schema)") + } + }) + + t.Run("chain with ToolCallResponse returns tool definition", func(t *testing.T) { + chain := []llms.MessageContent{ + { + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ + ToolCallID: "call_abc", + Name: "calculate", + Content: "42", + }, + }, + }, + } + + result := buildMinimalToolsFromChain(chain) + if len(result) != 1 { + t.Fatalf("expected 1 tool definition, got %d", len(result)) + } + if result[0].Function.Name != "calculate" { + t.Errorf("expected tool name 'calculate', got '%s'", result[0].Function.Name) + } + }) + + t.Run("chain with both ToolCall and ToolCallResponse deduplicates names", func(t *testing.T) { + chain := []llms.MessageContent{ + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: "call_1", + Type: "function", + FunctionCall: &llms.FunctionCall{ + Name: "search", + Arguments: `{"query":"go"}`, + }, + }, + }, + }, + { + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ + ToolCallID: "call_1", + Name: "search", // same name — should deduplicate + Content: "results", + }, + }, + }, + } + + result := buildMinimalToolsFromChain(chain) + if len(result) != 1 { + t.Fatalf("expected 1 deduplicated tool definition, got %d (%v)", len(result), toolNames(result)) + } + if result[0].Function.Name != "search" { + t.Errorf("expected tool name 'search', got '%s'", result[0].Function.Name) + } + }) + + t.Run("chain with multiple distinct tool names returns all", func(t *testing.T) { + chain := []llms.MessageContent{ + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: "c1", + Type: "function", + FunctionCall: &llms.FunctionCall{Name: "search"}, + }, + }, + }, + { + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ToolCallID: "c1", Name: "search", Content: "r"}, + }, + }, + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: "c2", + Type: "function", + FunctionCall: &llms.FunctionCall{Name: "execute_command"}, + }, + }, + }, + { + Role: llms.ChatMessageTypeTool, + Parts: []llms.ContentPart{ + llms.ToolCallResponse{ToolCallID: "c2", Name: "execute_command", Content: "ok"}, + }, + }, + } + + result := buildMinimalToolsFromChain(chain) + if len(result) != 2 { + t.Fatalf("expected 2 tool definitions, got %d (%v)", len(result), toolNames(result)) + } + + names := toolNames(result) + sort.Strings(names) + expected := []string{"execute_command", "search"} + for i, n := range expected { + if names[i] != n { + t.Errorf("expected name[%d]=%s, got %s", i, n, names[i]) + } + } + }) + + t.Run("ToolCall without FunctionCall is skipped", func(t *testing.T) { + chain := []llms.MessageContent{ + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ID: "c1", Type: "function"}, + // FunctionCall is nil — should be ignored + }, + }, + } + result := buildMinimalToolsFromChain(chain) + if result != nil { + t.Errorf("expected nil when ToolCall has nil FunctionCall, got %v", result) + } + }) + + t.Run("generated tool has valid placeholder schema", func(t *testing.T) { + chain := []llms.MessageContent{ + { + Role: llms.ChatMessageTypeAI, + Parts: []llms.ContentPart{ + llms.ToolCall{ + ID: "c1", + Type: "function", + FunctionCall: &llms.FunctionCall{ + Name: "my_tool", + Arguments: `{}`, + }, + }, + }, + }, + } + + result := buildMinimalToolsFromChain(chain) + if len(result) != 1 { + t.Fatalf("expected 1 tool, got %d", len(result)) + } + tool := result[0] + + // Verify placeholder description follows the expected format + expectedDesc := fmt.Sprintf("Tool: %s", "my_tool") + if tool.Function.Description != expectedDesc { + t.Errorf("expected description %q, got %q", expectedDesc, tool.Function.Description) + } + + // Verify placeholder schema has the required object type + schema, ok := tool.Function.Parameters.(map[string]any) + if !ok { + t.Fatalf("expected Parameters to be map[string]any, got %T", tool.Function.Parameters) + } + if schema["type"] != "object" { + t.Errorf("expected schema type 'object', got %v", schema["type"]) + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("expected properties to be map[string]any, got %T", schema["properties"]) + } + if len(props) != 0 { + t.Errorf("expected empty properties map, got %v", props) + } + }) +} + +// toolNames is a test helper that extracts tool names from a slice of llms.Tool. +func toolNames(tools []llms.Tool) []string { + names := make([]string, 0, len(tools)) + for _, t := range tools { + if t.Function != nil { + names = append(names, t.Function.Name) + } + } + return names +}