Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Private async execution utilities."""

import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor
from typing import Awaitable, Callable, TypeVar

Expand All @@ -27,5 +28,6 @@ def execute() -> T:
return asyncio.run(execute_async())

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
context = contextvars.copy_context()
future = executor.submit(context.run, execute)
return future.result()
Empty file added tests_integ/tools/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests_integ/tools/test_thread_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import contextvars

import pytest

from strands import Agent, tool


@pytest.fixture
def result():
return {}


@pytest.fixture
def contextvar():
return contextvars.ContextVar("agent")


@pytest.fixture
def context_tool(result, contextvar):
@tool(name="context_tool")
def tool_():
result["context_value"] = contextvar.get("local_context")

return tool_


@pytest.fixture
def agent(context_tool):
return Agent(tools=[context_tool])


def test_agent_invoke_context_sharing(result, contextvar, agent):
contextvar.set("shared_context")
agent("Execute context_tool")

tru_context = result["context_value"]
exp_context = contextvar.get()
assert tru_context == exp_context


def test_tool_call_context_sharing(result, contextvar, agent):
contextvar.set("shared_context")
agent.tool.context_tool()

tru_context = result["context_value"]
exp_context = contextvar.get()
assert tru_context == exp_context
Loading