Skip to content

.Net: Adds support for Tool calls to the Amazon Bedrock connector #11922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using Amazon.Runtime.Documents;
using Amazon.Runtime.Endpoints;
using Amazon.Runtime.EventStreams;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Services;
using Moq;
Expand Down Expand Up @@ -271,6 +275,74 @@ public async Task GetChatMessageContentsAsyncShouldHaveProperChatHistoryAsync()
Assert.Equal("That's great to hear!", chatHistory[5].Items[0].ToString());
}

/// <summary>
/// Checks that the chat history with binary content is given the correct values through calling GetChatMessageContentsAsync.
/// </summary>
[Fact]
public async Task GetChatMessageContentsWithBinaryContentAsyncShouldHaveProperChatHistoryAsync()
{
// Arrange
string modelId = "amazon.titan-embed-text-v1:0";
var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ConverseRequest>()))
.Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
{
URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
});

// Set up the mock ConverseAsync to return multiple responses
mockBedrockApi.SetupSequence(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(this.CreateConverseResponse("Here is the result.", ConversationRole.Assistant));

var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService(modelId, mockBedrockApi.Object).Build();
var service = kernel.GetRequiredService<IChatCompletionService>();
var chatHistory = CreateSampleChatHistoryWithBinaryContent();

// Act
var result = await service.GetChatMessageContentsAsync(chatHistory).ConfigureAwait(true);

// Assert
string? chatResult = result[0].Content;
Assert.NotNull(chatResult);

// Check the first result
Assert.Equal(AuthorRole.Assistant, result[0].Role);
Assert.Single(result[0].Items);
Assert.Equal("Here is the result.", result[0].Items[0].ToString());

// Check the chat history
Assert.Equal(6, chatHistory.Count); // Use the Count property to get the number of messages

Assert.Equal(AuthorRole.System, chatHistory[0].Role);
Assert.Equal("You are an AI Assistant", chatHistory[0].Items[0].ToString());

Assert.Equal(AuthorRole.User, chatHistory[1].Role); // Use the indexer to access individual messages
Assert.Equal("Hello", chatHistory[1].Items[0].ToString());

Assert.Equal(AuthorRole.Assistant, chatHistory[2].Role);
Assert.Equal("Hi", chatHistory[2].Items[0].ToString());

Assert.Equal(AuthorRole.User, chatHistory[3].Role);
Assert.Equal("How are you?", chatHistory[3].Items[0].ToString());

Assert.Equal(AuthorRole.Assistant, chatHistory[4].Role);
Assert.Equal("Fine, thanks. How can I help?", chatHistory[4].Items[0].ToString());

Assert.Equal(AuthorRole.User, chatHistory[5].Role);
Assert.Collection(chatHistory[5].Items,
c =>
{
Assert.IsType<TextContent>(c);
var item = (TextContent)c;
Assert.Equal("I need you to summarize these attachments.", item.Text);
},
c => Assert.IsType<ImageContent>(c),
c => Assert.IsType<PdfContent>(c),
c => Assert.IsType<DocxContent>(c),
c => Assert.IsType<ImageContent>(c)
);
}

/// <summary>
/// Checks that error handling present for empty chat history.
/// </summary>
Expand Down Expand Up @@ -379,6 +451,189 @@ public async Task ShouldHandleEmptyChatHistoryMessagesAsync()
// and doesn't throw an exception
}

private sealed class TestPlugin
{
[KernelFunction()]
[Description("Given a document title, look up the corresponding document ID for it.")]
[return: Description("The identified document if found, or an empty string if not.")]
public string FindDocumentIdForTitle(
[Description("The title to retrieve a corresponding ID for")]
string title
)
{
return $"{title}-{Guid.NewGuid()}";
}
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ShouldHandleToolsInConverseRequestAsync(bool required)
{
// Arrange
ConverseRequest? firstRequest = null;
ConverseRequest? secondRequest = null;
var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ConverseRequest>()))
.Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
{
URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
});
mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny<ConverseRequest>(), It.IsAny<CancellationToken>()))
.Callback((ConverseRequest request, CancellationToken token) =>
{
if (firstRequest == null)
{
firstRequest = request;
}
else
{
secondRequest = request;
}
})
.ReturnsAsync((ConverseRequest request, CancellationToken _) =>
{
return secondRequest == null
? new ConverseResponse
{
Output = new ConverseOutput
{
Message = new Message
{
Role = ConversationRole.Assistant,
Content = [ new() { ToolUse = new ToolUseBlock
{
ToolUseId = "tool-use-id-1",
Name = "TestPlugin-FindDocumentIdForTitle",
Input = Document.FromObject(new Dictionary<string, object>
{
["title"] = "Green Eggs and Ham",
}),
} } ]
},
},
Metrics = new ConverseMetrics(),
StopReason = StopReason.Tool_use,
Usage = new TokenUsage()
}
: this.CreateConverseResponse("Hello, world!", ConversationRole.Assistant);
});
var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService("amazon.titan-text-premier-v1:0", mockBedrockApi.Object).Build();
var plugin = new TestPlugin();
kernel.ImportPluginFromObject(plugin);
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("Find the ID corresponding to the title 'Green Eggs and Ham', by Dr. Suess.");
var service = kernel.GetRequiredService<IChatCompletionService>();
var executionSettings = AmazonClaudeExecutionSettings.FromExecutionSettings(null);
executionSettings.FunctionChoiceBehavior = required ? FunctionChoiceBehavior.Required() : FunctionChoiceBehavior.Auto();

// Act
var result = await service.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, CancellationToken.None).ConfigureAwait(true);

Assert.NotNull(firstRequest?.ToolConfig);
if (required)
{
Assert.NotNull(firstRequest.ToolConfig.ToolChoice);
Assert.Null(firstRequest.ToolConfig.ToolChoice.Auto);
Assert.Equal("TestPlugin-FindDocumentIdForTitle", firstRequest.ToolConfig.ToolChoice?.Tool?.Name);
}
else // auto
{
Assert.NotNull(firstRequest.ToolConfig.ToolChoice?.Auto);
}
Assert.NotNull(secondRequest?.Messages.Last().Content?.FirstOrDefault(c => c.ToolResult != null));
}

[Fact(Skip = "This test is missing the binary stream containing the delta block events with tool use needed to test this API")]
public async Task ShouldHandleToolsInConverseStreamingRequestAsync()
{
// Arrange
ConverseStreamRequest? firstRequest = null;
ConverseStreamRequest? secondRequest = null;
var mockBedrockApi = new Mock<IAmazonBedrockRuntime>();
mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny<ConverseStreamRequest>()))
.Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com")
{
URL = "https://bedrock-runtime.us-east-1.amazonaws.com"
});
List<IEventStreamEvent> firstSequence = [
new ContentBlockStartEvent
{
ContentBlockIndex = 0,
Start = new ContentBlockStart
{
ToolUse = new ToolUseBlockStart
{
ToolUseId = "tool-use-id-1",
Name = "TestPlugin-FindDocumentIdForTitle",
}
}
},
new ContentBlockDeltaEvent
{
ContentBlockIndex = 1,
Delta = new ContentBlockDelta
{
ToolUse = new ToolUseBlockDelta
{
Input = """
{
"title": "Green Eggs and Ham"
}
""",
}
}
},
new ContentBlockStopEvent
{
ContentBlockIndex = 2,
}
];
mockBedrockApi.Setup(m => m.ConverseStreamAsync(It.IsAny<ConverseStreamRequest>(), It.IsAny<CancellationToken>()))
.Callback((ConverseStreamRequest request, CancellationToken token) =>
{
if (firstRequest == null)
{
firstRequest = request;
}
else
{
secondRequest = request;
}
})
.ReturnsAsync((ConverseStreamRequest request, CancellationToken _) =>
{
return new ConverseStreamResponse
{
HttpStatusCode = System.Net.HttpStatusCode.OK,
// TODO: Replace with actual stream containing the delta block events with tool use
Stream = new ConverseStreamOutput(new MemoryStream())
};
});

var kernel = Kernel.CreateBuilder().AddBedrockChatCompletionService("amazon.titan-text-premier-v1:0", mockBedrockApi.Object).Build();
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("Stream the ID corresponding to the title 'Green Eggs and Ham', by Dr. Suess.");
var service = kernel.GetRequiredService<IChatCompletionService>();
var executionSettings = new AmazonClaudeExecutionSettings
{
ModelId = "amazon.titan-text-premier-v1:0",
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(),
};

// Act
var result = service.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, CancellationToken.None).ConfigureAwait(true);
var stream = new List<StreamingChatMessageContent>();
await foreach (var msg in result)
{
stream.Add(msg);
}

// Assert
Assert.NotNull(firstRequest?.ToolConfig);
Assert.NotNull(secondRequest?.Messages.Last().Content?.FirstOrDefault(c => c.ToolResult != null));
}

private static ChatHistory CreateSampleChatHistory()
{
var chatHistory = new ChatHistory();
Expand All @@ -389,11 +644,38 @@ private static ChatHistory CreateSampleChatHistory()
return chatHistory;
}

private byte[] GetTestResponseAsBytes(string fileName)
private static ChatHistory CreateSampleChatHistoryWithBinaryContent()
{
var chatHistory = new ChatHistory();
chatHistory.AddSystemMessage("You are an AI Assistant");
chatHistory.AddUserMessage("Hello");
chatHistory.AddAssistantMessage("Hi");
chatHistory.AddUserMessage("How are you?");
chatHistory.AddAssistantMessage("Fine, thanks. How can I help?");
chatHistory.AddUserMessage(
[
new TextContent("I need you to summarize these attachments."),
new ImageContent(new Uri("https://example.com/image.jpg")),
new PdfContent(GetTestDataFileContentsAsBase64String("SemanticKernelCookBook.en.pdf", "application/pdf")),
new DocxContent(GetTestDataFileContentsAsBase64String("test-doc.docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document")),
new ImageContent(GetTestDataFileContentsAsBase64String("the-planner.png", "image/png")),
]);
return chatHistory;
}

private byte[] GetTestResponseAsBytes(string fileName) => GetTestDataFileContentsAsBytes(fileName);

private static byte[] GetTestDataFileContentsAsBytes(string fileName)
{
return File.ReadAllBytes($"TestData/{fileName}");
}

private static string GetTestDataFileContentsAsBase64String(string fileName, string mimeType)
{
var content = Convert.ToBase64String(GetTestDataFileContentsAsBytes(fileName));
return $"data:{mimeType};base64,{content}";
}

private ConverseResponse CreateConverseResponse(string text, ConversationRole role)
{
return new ConverseResponse
Expand Down
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading