Skip to content

Commit 8a2582c

Browse files
authored
SK KernelFunction from ToolSchemas (#6637)
## Why are these changes needed? Only a subset of available tools will sent to SK ## Related issue number resolves #6582 ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [ ] I've made sure all auto checks have passed.
1 parent 348bcb1 commit 8a2582c

File tree

4 files changed

+74
-5
lines changed

4 files changed

+74
-5
lines changed

python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from semantic_kernel.kernel import Kernel
3232
from typing_extensions import AsyncGenerator, Union
3333

34-
from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool
34+
from autogen_ext.tools.semantic_kernel import KernelFunctionFromTool, KernelFunctionFromToolSchema
3535

3636
from .._utils.parse_r1_content import parse_r1_content
3737

@@ -396,6 +396,9 @@ def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSch
396396
# Convert Tool to KernelFunction using KernelFunctionFromTool
397397
kernel_function = KernelFunctionFromTool(tool) # type: ignore
398398
self._tools_plugin.functions[tool.schema["name"]] = kernel_function
399+
else:
400+
kernel_function = KernelFunctionFromToolSchema(tool) # type: ignore
401+
self._tools_plugin.functions[tool.get("name")] = kernel_function # type: ignore
399402

400403
kernel.add_plugin(self._tools_plugin)
401404

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from ._kernel_function_from_tool import KernelFunctionFromTool
1+
from ._kernel_function_from_tool import KernelFunctionFromTool, KernelFunctionFromToolSchema
22

33
__all__ = [
44
"KernelFunctionFromTool",
5+
"KernelFunctionFromToolSchema",
56
]

python/packages/autogen-ext/src/autogen_ext/tools/semantic_kernel/_kernel_function_from_tool.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from typing import Any, TypeVar
22

33
from autogen_core import CancellationToken
4-
from autogen_core.tools import BaseTool
4+
from autogen_core.tools import BaseTool, ToolSchema
55
from pydantic import BaseModel
6-
from semantic_kernel.functions import KernelFunctionFromMethod, kernel_function
6+
from semantic_kernel.functions import KernelFunctionFromMethod, KernelFunctionFromPrompt, kernel_function
77
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
8+
from semantic_kernel.prompt_template.input_variable import InputVariable
9+
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
810

911
InputT = TypeVar("InputT", bound=BaseModel)
1012
OutputT = TypeVar("OutputT", bound=BaseModel)
@@ -65,3 +67,27 @@ async def tool_method(**kwargs: dict[str, Any]) -> Any:
6567
)
6668

6769
self._tool = tool
70+
71+
72+
class KernelFunctionFromToolSchema(KernelFunctionFromPrompt):
73+
def __init__(self, tool_schema: ToolSchema, plugin_name: str | None = None):
74+
properties = tool_schema.get("parameters", {}).get("properties", {})
75+
required = properties.get("required", [])
76+
77+
prompt_template_config = PromptTemplateConfig(
78+
name=tool_schema.get("name", ""),
79+
description=tool_schema.get("description", ""),
80+
input_variables=[
81+
InputVariable(
82+
name=prop_name, description=prop_info.get("description", ""), is_required=prop_name in required
83+
)
84+
for prop_name, prop_info in properties.items()
85+
],
86+
)
87+
88+
super().__init__(
89+
function_name=tool_schema.get("name", ""),
90+
plugin_name=plugin_name,
91+
description=tool_schema.get("description", ""),
92+
prompt_template_config=prompt_template_config,
93+
)

python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
SystemMessage,
1717
UserMessage,
1818
)
19-
from autogen_core.tools import BaseTool
19+
from autogen_core.tools import BaseTool, ParametersSchema, ToolSchema
2020
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
2121
from openai.types.chat.chat_completion_chunk import (
2222
ChatCompletionChunk,
@@ -335,6 +335,45 @@ async def test_sk_chat_completion_with_tools(sk_client: AzureChatCompletion) ->
335335
assert not result.cached
336336

337337

338+
@pytest.mark.asyncio
339+
async def test_sk_chat_completion_with_prompt_tools(sk_client: AzureChatCompletion) -> None:
340+
# Create adapter
341+
adapter = SKChatCompletionAdapter(sk_client)
342+
343+
# Create kernel
344+
kernel = Kernel(memory=NullMemory())
345+
346+
# Create calculator tool instance
347+
tool: ToolSchema = ToolSchema(
348+
name="calculator",
349+
description="Add two numbers together",
350+
parameters=ParametersSchema(
351+
type="object",
352+
properties={
353+
"a": {"type": "number", "description": "First number"},
354+
"b": {"type": "number", "description": "Second number"},
355+
},
356+
required=["a", "b"],
357+
),
358+
)
359+
360+
# Test messages
361+
messages: list[LLMMessage] = [
362+
SystemMessage(content="You are a helpful assistant."),
363+
UserMessage(content="What is 2 + 2?", source="user"),
364+
]
365+
366+
# Call create with tool
367+
result = await adapter.create(messages=messages, tools=[tool], extra_create_args={"kernel": kernel})
368+
369+
# Verify response
370+
assert isinstance(result.content, list)
371+
assert result.finish_reason == "function_calls"
372+
assert result.usage.prompt_tokens >= 0
373+
assert result.usage.completion_tokens >= 0
374+
assert not result.cached
375+
376+
338377
@pytest.mark.asyncio
339378
async def test_sk_chat_completion_without_tools(
340379
sk_client: AzureChatCompletion, caplog: pytest.LogCaptureFixture

0 commit comments

Comments
 (0)