Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 74 additions & 3 deletions backend/pkg/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,24 @@ func (p *bedrockProvider) CallEx(
chain []llms.MessageContent,
streamCb streaming.Callback,
) (*llms.ContentResponse, error) {
options := []llms.CallOption{
llms.WithStreamingFunc(streamCb),
}

// The AWS Bedrock Converse API requires toolConfig to be defined whenever the
// conversation history contains toolUse or toolResult content blocks — even when
// no new tools are being offered in the current turn. Without it the API returns:
// ValidationException: The toolConfig field must be defined when using
// toolUse and toolResult content blocks.
// We reconstruct minimal tool definitions from the tool-call names already
// present in the chain so that the library sets toolConfig automatically.
if minimalTools := buildMinimalToolsFromChain(chain); len(minimalTools) > 0 {
options = append(options, llms.WithTools(minimalTools))
}

return provider.WrapGenerateContent(
ctx, p, opt, p.llm.GenerateContent, chain,
append([]llms.CallOption{
llms.WithStreamingFunc(streamCb),
}, p.providerConfig.GetOptionsForType(opt)...)...,
append(options, p.providerConfig.GetOptionsForType(opt)...)...,
)
}

Expand All @@ -183,6 +196,14 @@ func (p *bedrockProvider) CallWithTools(
tools []llms.Tool,
streamCb streaming.Callback,
) (*llms.ContentResponse, error) {
// Same Bedrock Converse API requirement as in CallEx: if no tools were
// explicitly provided for this turn but the chain already carries toolUse /
// toolResult blocks, reconstruct minimal definitions so that the library
// includes toolConfig in the request.
if len(tools) == 0 {
tools = buildMinimalToolsFromChain(chain)
}

return provider.WrapGenerateContent(
ctx, p, opt, p.llm.GenerateContent, chain,
append([]llms.CallOption{
Expand All @@ -199,3 +220,53 @@ func (p *bedrockProvider) GetUsage(info map[string]any) pconfig.CallUsage {
func (p *bedrockProvider) GetToolCallIDTemplate(ctx context.Context, prompter templates.Prompter) (string, error) {
return provider.DetermineToolCallIDTemplate(ctx, p, pconfig.OptionsTypeSimple, prompter)
}

// buildMinimalToolsFromChain inspects a conversation chain for ToolCall and
// ToolCallResponse parts and returns minimal llms.Tool definitions for every
// unique tool name found. This is required by the AWS Bedrock Converse API:
// whenever the request messages contain toolUse or toolResult content blocks,
// the request MUST also include a valid toolConfig — otherwise the API returns
//
// ValidationException: The toolConfig field must be defined when using
// toolUse and toolResult content blocks.
//
// We intentionally build placeholder definitions (no real parameter schema)
// because Bedrock only validates that the config is present, not that the
// schemas match historical usage.
func buildMinimalToolsFromChain(chain []llms.MessageContent) []llms.Tool {
seen := make(map[string]struct{})
for _, msg := range chain {
for _, part := range msg.Parts {
switch p := part.(type) {
case llms.ToolCall:
if p.FunctionCall != nil && p.FunctionCall.Name != "" {
seen[p.FunctionCall.Name] = struct{}{}
}
case llms.ToolCallResponse:
if p.Name != "" {
seen[p.Name] = struct{}{}
}
}
}
}

if len(seen) == 0 {
return nil
}

tools := make([]llms.Tool, 0, len(seen))
for name := range seen {
tools = append(tools, llms.Tool{
Type: "function",
Function: &llms.FunctionDefinition{
Name: name,
Description: fmt.Sprintf("Tool: %s", name),
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{},
},
},
})
}
return tools
}
252 changes: 252 additions & 0 deletions backend/pkg/providers/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package bedrock

import (
"fmt"
"sort"
"testing"

"pentagi/pkg/config"
"pentagi/pkg/providers/pconfig"
"pentagi/pkg/providers/provider"

"github.com/vxcontrol/langchaingo/llms"
"github.com/vxcontrol/langchaingo/llms/bedrock"
)

Expand Down Expand Up @@ -181,3 +184,252 @@ func TestGetUsage(t *testing.T) {
t.Errorf("Expected zero tokens with empty usage info, got %s", usage.String())
}
}

// TestBuildMinimalToolsFromChain verifies the helper that reconstructs minimal
// tool definitions from a conversation chain. This is the foundation of the
// fix for the AWS Bedrock Converse API ValidationException that occurs when
// messages contain toolUse / toolResult blocks but no toolConfig is provided.
func TestBuildMinimalToolsFromChain(t *testing.T) {
t.Run("empty chain returns nil", func(t *testing.T) {
result := buildMinimalToolsFromChain(nil)
if result != nil {
t.Errorf("expected nil for empty chain, got %v", result)
}

result = buildMinimalToolsFromChain([]llms.MessageContent{})
if result != nil {
t.Errorf("expected nil for empty chain slice, got %v", result)
}
})

t.Run("chain with only text messages returns nil", func(t *testing.T) {
chain := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeSystem, "You are helpful."),
llms.TextParts(llms.ChatMessageTypeHuman, "What is 2+2?"),
llms.TextParts(llms.ChatMessageTypeAI, "4"),
}
result := buildMinimalToolsFromChain(chain)
if result != nil {
t.Errorf("expected nil for text-only chain, got %v", result)
}
})

t.Run("chain with ToolCall returns tool definition", func(t *testing.T) {
chain := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather?"),
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{
ID: "call_abc",
Type: "function",
FunctionCall: &llms.FunctionCall{
Name: "get_weather",
Arguments: `{"location":"NYC"}`,
},
},
},
},
}

result := buildMinimalToolsFromChain(chain)
if len(result) != 1 {
t.Fatalf("expected 1 tool definition, got %d", len(result))
}
if result[0].Function == nil {
t.Fatal("tool Function must not be nil")
}
if result[0].Function.Name != "get_weather" {
t.Errorf("expected tool name 'get_weather', got '%s'", result[0].Function.Name)
}
if result[0].Type != "function" {
t.Errorf("expected type 'function', got '%s'", result[0].Type)
}
if result[0].Function.Parameters == nil {
t.Error("expected non-nil Parameters (placeholder schema)")
}
})

t.Run("chain with ToolCallResponse returns tool definition", func(t *testing.T) {
chain := []llms.MessageContent{
{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.ToolCallResponse{
ToolCallID: "call_abc",
Name: "calculate",
Content: "42",
},
},
},
}

result := buildMinimalToolsFromChain(chain)
if len(result) != 1 {
t.Fatalf("expected 1 tool definition, got %d", len(result))
}
if result[0].Function.Name != "calculate" {
t.Errorf("expected tool name 'calculate', got '%s'", result[0].Function.Name)
}
})

t.Run("chain with both ToolCall and ToolCallResponse deduplicates names", func(t *testing.T) {
chain := []llms.MessageContent{
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{
ID: "call_1",
Type: "function",
FunctionCall: &llms.FunctionCall{
Name: "search",
Arguments: `{"query":"go"}`,
},
},
},
},
{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.ToolCallResponse{
ToolCallID: "call_1",
Name: "search", // same name — should deduplicate
Content: "results",
},
},
},
}

result := buildMinimalToolsFromChain(chain)
if len(result) != 1 {
t.Fatalf("expected 1 deduplicated tool definition, got %d (%v)", len(result), toolNames(result))
}
if result[0].Function.Name != "search" {
t.Errorf("expected tool name 'search', got '%s'", result[0].Function.Name)
}
})

t.Run("chain with multiple distinct tool names returns all", func(t *testing.T) {
chain := []llms.MessageContent{
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{
ID: "c1",
Type: "function",
FunctionCall: &llms.FunctionCall{Name: "search"},
},
},
},
{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.ToolCallResponse{ToolCallID: "c1", Name: "search", Content: "r"},
},
},
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{
ID: "c2",
Type: "function",
FunctionCall: &llms.FunctionCall{Name: "execute_command"},
},
},
},
{
Role: llms.ChatMessageTypeTool,
Parts: []llms.ContentPart{
llms.ToolCallResponse{ToolCallID: "c2", Name: "execute_command", Content: "ok"},
},
},
}

result := buildMinimalToolsFromChain(chain)
if len(result) != 2 {
t.Fatalf("expected 2 tool definitions, got %d (%v)", len(result), toolNames(result))
}

names := toolNames(result)
sort.Strings(names)
expected := []string{"execute_command", "search"}
for i, n := range expected {
if names[i] != n {
t.Errorf("expected name[%d]=%s, got %s", i, n, names[i])
}
}
})

t.Run("ToolCall without FunctionCall is skipped", func(t *testing.T) {
chain := []llms.MessageContent{
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{ID: "c1", Type: "function"},
// FunctionCall is nil — should be ignored
},
},
}
result := buildMinimalToolsFromChain(chain)
if result != nil {
t.Errorf("expected nil when ToolCall has nil FunctionCall, got %v", result)
}
})

t.Run("generated tool has valid placeholder schema", func(t *testing.T) {
chain := []llms.MessageContent{
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{
llms.ToolCall{
ID: "c1",
Type: "function",
FunctionCall: &llms.FunctionCall{
Name: "my_tool",
Arguments: `{}`,
},
},
},
},
}

result := buildMinimalToolsFromChain(chain)
if len(result) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result))
}
tool := result[0]

// Verify placeholder description follows the expected format
expectedDesc := fmt.Sprintf("Tool: %s", "my_tool")
if tool.Function.Description != expectedDesc {
t.Errorf("expected description %q, got %q", expectedDesc, tool.Function.Description)
}

// Verify placeholder schema has the required object type
schema, ok := tool.Function.Parameters.(map[string]any)
if !ok {
t.Fatalf("expected Parameters to be map[string]any, got %T", tool.Function.Parameters)
}
if schema["type"] != "object" {
t.Errorf("expected schema type 'object', got %v", schema["type"])
}
props, ok := schema["properties"].(map[string]any)
if !ok {
t.Fatalf("expected properties to be map[string]any, got %T", schema["properties"])
}
if len(props) != 0 {
t.Errorf("expected empty properties map, got %v", props)
}
})
}

// toolNames is a test helper that extracts tool names from a slice of llms.Tool.
func toolNames(tools []llms.Tool) []string {
names := make([]string, 0, len(tools))
for _, t := range tools {
if t.Function != nil {
names = append(names, t.Function.Name)
}
}
return names
}