Skip to content

Commit cef5ae7

Browse files
Python: Support | and |= operators for KernelArgument (#12499)
### Motivation and Context `KernelArguments` extends the built-in `dict` by adding an `execution_settings` attribute. However, when using the `|` and `|=` operators, only the `dict` part (excluding `execution_settings`) is merged, and the result becomes a plain `dict`. This causes the `execution_settings` attribute to be lost and not updated. ### Description - Implemented support for the `|` and `|=` operators to make it easier for users to merge `KernelArgument` objects. - The right-hand or left-hand side of the operator can be either a `KernelArgument` or a `dict` (or any subclass of dict). - When merging, both the `execution_settings` attribute and the rest of the `dict` data are merged individually, and the result is returned as a `KernelArgument`. - As long as either side of the operator is a `KernelArgument`, the result will also be a `KernelArgument`. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent 206655a commit cef5ae7

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed

python/semantic_kernel/functions/kernel_arguments.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from semantic_kernel.const import DEFAULT_SERVICE_NAME
66

77
if TYPE_CHECKING:
8+
from collections.abc import Iterable
9+
10+
from _typeshed import SupportsKeysAndGetItem
11+
812
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
913

1014

@@ -49,3 +53,53 @@ def __bool__(self) -> bool:
4953
has_arguments = self.__len__() > 0
5054
has_execution_settings = self.execution_settings is not None and len(self.execution_settings) > 0
5155
return has_arguments or has_execution_settings
56+
57+
def __or__(self, value: dict) -> "KernelArguments":
58+
"""Merges a KernelArguments with another KernelArguments or dict.
59+
60+
This implements the `|` operator for KernelArguments.
61+
"""
62+
if not isinstance(value, dict):
63+
raise TypeError(
64+
f"TypeError: unsupported operand type(s) for |: '{type(self).__name__}' and '{type(value).__name__}'"
65+
)
66+
67+
# Merge execution settings
68+
new_execution_settings = (self.execution_settings or {}).copy()
69+
if isinstance(value, KernelArguments) and value.execution_settings:
70+
new_execution_settings |= value.execution_settings
71+
# Create a new KernelArguments with merged dict values
72+
return KernelArguments(settings=new_execution_settings, **(dict(self) | dict(value)))
73+
74+
def __ror__(self, value: dict) -> "KernelArguments":
75+
"""Merges a dict with a KernelArguments.
76+
77+
This implements the right-side `|` operator for KernelArguments.
78+
"""
79+
if not isinstance(value, dict):
80+
raise TypeError(
81+
f"TypeError: unsupported operand type(s) for |: '{type(value).__name__}' and '{type(self).__name__}'"
82+
)
83+
84+
# Merge execution settings
85+
new_execution_settings = {}
86+
if isinstance(value, KernelArguments) and value.execution_settings:
87+
new_execution_settings = value.execution_settings.copy()
88+
if self.execution_settings:
89+
new_execution_settings |= self.execution_settings
90+
91+
# Create a new KernelArguments with merged dict values
92+
return KernelArguments(settings=new_execution_settings, **(dict(value) | dict(self)))
93+
94+
def __ior__(self, value: "SupportsKeysAndGetItem[Any, Any] | Iterable[tuple[Any, Any]]") -> "KernelArguments":
95+
"""Merges into this KernelArguments with another KernelArguments or dict (in-place)."""
96+
self.update(value)
97+
98+
# In-place merge execution settings
99+
if isinstance(value, KernelArguments) and value.execution_settings:
100+
if self.execution_settings:
101+
self.execution_settings.update(value.execution_settings)
102+
else:
103+
self.execution_settings = value.execution_settings.copy()
104+
105+
return self

python/tests/unit/functions/test_kernel_arguments.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3+
import pytest
4+
35
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
46
from semantic_kernel.functions.kernel_arguments import KernelArguments
57

@@ -46,3 +48,134 @@ def test_kernel_arguments_bool():
4648
assert KernelArguments(settings=PromptExecutionSettings(service_id="test"))
4749
# An KernelArguments object with both keyword arguments and execution_settings should return True
4850
assert KernelArguments(input=10, settings=PromptExecutionSettings(service_id="test"))
51+
52+
53+
@pytest.mark.parametrize(
54+
"lhs, rhs, expected_dict, expected_settings_keys",
55+
[
56+
# Merging different keys
57+
(KernelArguments(a=1), KernelArguments(b=2), {"a": 1, "b": 2}, None),
58+
# RHS overwrites when keys duplicate
59+
(KernelArguments(a=1), KernelArguments(a=99), {"a": 99}, None),
60+
# Merging with a plain dict
61+
(KernelArguments(a=1), {"b": 2}, {"a": 1, "b": 2}, None),
62+
# Merging execution_settings together
63+
(
64+
KernelArguments(settings=PromptExecutionSettings(service_id="s1")),
65+
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
66+
{},
67+
["s1", "s2"],
68+
),
69+
# Same service_id is overwritten by RHS
70+
(
71+
KernelArguments(settings=PromptExecutionSettings(service_id="shared")),
72+
KernelArguments(settings=PromptExecutionSettings(service_id="shared")),
73+
{},
74+
["shared"],
75+
),
76+
],
77+
)
78+
def test_kernel_arguments_or_operator(lhs, rhs, expected_dict, expected_settings_keys):
79+
"""Test the __or__ operator (lhs | rhs) with various argument combinations."""
80+
result = lhs | rhs
81+
assert isinstance(result, KernelArguments)
82+
assert dict(result) == expected_dict
83+
if expected_settings_keys is None:
84+
assert result.execution_settings is None
85+
else:
86+
assert sorted(result.execution_settings.keys()) == sorted(expected_settings_keys)
87+
88+
89+
@pytest.mark.parametrize("rhs", [42, "foo", None])
90+
def test_kernel_arguments_or_operator_with_invalid_type(rhs):
91+
"""Test the __or__ operator with an invalid type raises TypeError."""
92+
with pytest.raises(TypeError):
93+
KernelArguments() | rhs
94+
95+
96+
@pytest.mark.parametrize(
97+
"lhs, rhs, expected_dict, expected_settings_keys",
98+
[
99+
# Dict merge (in-place)
100+
(KernelArguments(a=1), {"b": 2}, {"a": 1, "b": 2}, None),
101+
# Merging between KernelArguments
102+
(KernelArguments(a=1), KernelArguments(b=2), {"a": 1, "b": 2}, None),
103+
# Retain existing execution_settings after dict merge
104+
(KernelArguments(a=1, settings=PromptExecutionSettings(service_id="s1")), {"b": 2}, {"a": 1, "b": 2}, ["s1"]),
105+
# In-place merge of execution_settings
106+
(
107+
KernelArguments(settings=PromptExecutionSettings(service_id="s1")),
108+
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
109+
{},
110+
["s1", "s2"],
111+
),
112+
],
113+
)
114+
def test_kernel_arguments_inplace_merge(lhs, rhs, expected_dict, expected_settings_keys):
115+
"""Test the |= operator with various argument combinations without execution_settings."""
116+
original_id = id(lhs)
117+
lhs |= rhs
118+
# Verify this is the same object (in-place)
119+
assert id(lhs) == original_id
120+
assert dict(lhs) == expected_dict
121+
if expected_settings_keys is None:
122+
assert lhs.execution_settings is None
123+
else:
124+
assert sorted(lhs.execution_settings.keys()) == sorted(expected_settings_keys)
125+
126+
127+
@pytest.mark.parametrize(
128+
"rhs, lhs, expected_dict, expected_settings_keys",
129+
[
130+
# Merging different keys
131+
({"b": 2}, KernelArguments(a=1), {"b": 2, "a": 1}, None),
132+
# RHS overwrites when keys duplicate
133+
({"a": 1}, KernelArguments(a=99), {"a": 99}, None),
134+
# Merging with a KernelArguments
135+
({"b": 2}, KernelArguments(a=1), {"b": 2, "a": 1}, None),
136+
# Merging execution_settings together
137+
(
138+
{"test": "value"},
139+
KernelArguments(settings=PromptExecutionSettings(service_id="s2")),
140+
{"test": "value"},
141+
["s2"],
142+
),
143+
# Plain dict on the left with KernelArguments+settings on the right
144+
(
145+
{"a": 1},
146+
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="shared")),
147+
{"a": 1, "b": 2},
148+
["shared"],
149+
),
150+
# KernelArguments on both sides with execution_settings
151+
(
152+
KernelArguments(a=1, settings=PromptExecutionSettings(service_id="s1")),
153+
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="s2")),
154+
{"a": 1, "b": 2},
155+
["s1", "s2"],
156+
),
157+
# Same service_id is overwritten by RHS (KernelArguments)
158+
(
159+
KernelArguments(a=1, settings=PromptExecutionSettings(service_id="shared")),
160+
KernelArguments(b=2, settings=PromptExecutionSettings(service_id="shared")),
161+
{"a": 1, "b": 2},
162+
["shared"],
163+
),
164+
],
165+
)
166+
def test_kernel_arguments_ror_operator(rhs, lhs, expected_dict, expected_settings_keys):
167+
"""Test the __ror__ operator (lhs | rhs) with various argument combinations."""
168+
result = rhs | lhs
169+
assert isinstance(result, KernelArguments)
170+
assert dict(result) == expected_dict
171+
if expected_settings_keys is None:
172+
assert result.execution_settings is None
173+
else:
174+
assert sorted(result.execution_settings.keys()) == sorted(expected_settings_keys)
175+
176+
177+
@pytest.mark.parametrize("lhs", [42, "foo", None])
178+
def test_kernel_arguments_ror_operator_with_invalid_type(lhs):
179+
"""Test the __ror__ operator with an invalid type raises TypeError."""
180+
with pytest.raises(TypeError):
181+
lhs | KernelArguments()

0 commit comments

Comments
 (0)