diff --git a/src/strands/_async.py b/src/strands/_async.py index 976487c37..141ca71b7 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -1,6 +1,7 @@ """Private async execution utilities.""" import asyncio +import contextvars from concurrent.futures import ThreadPoolExecutor from typing import Awaitable, Callable, TypeVar @@ -27,5 +28,6 @@ def execute() -> T: return asyncio.run(execute_async()) with ThreadPoolExecutor() as executor: - future = executor.submit(execute) + context = contextvars.copy_context() + future = executor.submit(context.run, execute) return future.result() diff --git a/tests_integ/tools/__init__.py b/tests_integ/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/tools/test_thread_context.py b/tests_integ/tools/test_thread_context.py new file mode 100644 index 000000000..b86c9b2c0 --- /dev/null +++ b/tests_integ/tools/test_thread_context.py @@ -0,0 +1,47 @@ +import contextvars + +import pytest + +from strands import Agent, tool + + +@pytest.fixture +def result(): + return {} + + +@pytest.fixture +def contextvar(): + return contextvars.ContextVar("agent") + + +@pytest.fixture +def context_tool(result, contextvar): + @tool(name="context_tool") + def tool_(): + result["context_value"] = contextvar.get("local_context") + + return tool_ + + +@pytest.fixture +def agent(context_tool): + return Agent(tools=[context_tool]) + + +def test_agent_invoke_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent("Execute context_tool") + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context + + +def test_tool_call_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent.tool.context_tool() + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context