Skip to content

Commit

Permalink
fix: simplify ContextVar and fix sub-span attribution when delegate…
Browse files Browse the repository at this point in the history
…d to thread (#13700)

* fix: simplify contextvar

* test trees in parallel

* fix Token.MISSING
  • Loading branch information
RogerHYang committed May 29, 2024
1 parent 03bb37f commit 1127908
Show file tree
Hide file tree
Showing 39 changed files with 310 additions and 365 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
from itertools import chain
from threading import Thread
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -42,6 +41,7 @@
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.types import Thread
from llama_index.core.utils import print_text, unit_generator


Expand Down
2 changes: 1 addition & 1 deletion llama-index-core/llama_index/core/agent/react/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import uuid
from functools import partial
from threading import Thread
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -53,6 +52,7 @@
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool, ToolOutput, adapt_to_async_tool
from llama_index.core.tools.types import AsyncBaseTool
from llama_index.core.types import Thread
from llama_index.core.utils import print_text


Expand Down
40 changes: 16 additions & 24 deletions llama-index-core/llama_index/core/agent/runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,9 @@ def _run_step(
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(AgentRunStepStartEvent(task_id=task_id, step=step, input=input))
dispatcher.event(
AgentRunStepStartEvent(task_id=task_id, step=step, input=input)
)
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
Expand All @@ -421,7 +421,7 @@ def _run_step(
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)

dispatch_event(AgentRunStepEndEvent(step_output=cur_step_output))
dispatcher.event(AgentRunStepEndEvent(step_output=cur_step_output))
return cur_step_output

@dispatcher.span
Expand All @@ -434,9 +434,9 @@ async def _arun_step(
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(AgentRunStepStartEvent(task_id=task_id, step=step, input=input))
dispatcher.event(
AgentRunStepStartEvent(task_id=task_id, step=step, input=input)
)
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
Expand All @@ -462,7 +462,7 @@ async def _arun_step(
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)

dispatch_event(AgentRunStepEndEvent(step_output=cur_step_output))
dispatcher.event(AgentRunStepEndEvent(step_output=cur_step_output))
return cur_step_output

@dispatcher.span
Expand Down Expand Up @@ -561,14 +561,12 @@ def _chat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)

result_output = None
dispatch_event(AgentChatWithStepStartEvent(user_msg=message))
dispatcher.event(AgentChatWithStepStartEvent(user_msg=message))
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = self._run_step(
Expand All @@ -586,7 +584,7 @@ def _chat(
task.task_id,
result_output,
)
dispatch_event(AgentChatWithStepEndEvent(response=result))
dispatcher.event(AgentChatWithStepEndEvent(response=result))
return result

@dispatcher.span
Expand All @@ -598,14 +596,12 @@ async def _achat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)

result_output = None
dispatch_event(AgentChatWithStepStartEvent(user_msg=message))
dispatcher.event(AgentChatWithStepStartEvent(user_msg=message))
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = await self._arun_step(
Expand All @@ -623,7 +619,7 @@ async def _achat(
task.task_id,
result_output,
)
dispatch_event(AgentChatWithStepEndEvent(response=result))
dispatcher.event(AgentChatWithStepEndEvent(response=result))
return result

@dispatcher.span
Expand Down Expand Up @@ -778,16 +774,14 @@ def _chat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)

# create initial set of tasks
plan_id = self.create_plan(message)

results = []
dispatch_event(AgentChatWithStepStartEvent(user_msg=message))
dispatcher.event(AgentChatWithStepStartEvent(user_msg=message))
while True:
# EXIT CONDITION: check if all sub-tasks are completed
next_task_ids = self.get_next_tasks(plan_id)
Expand All @@ -812,7 +806,7 @@ def _chat(
# refine the plan
self.refine_plan(message, plan_id)

dispatch_event(
dispatcher.event(
AgentChatWithStepEndEvent(
response=results[-1] if len(results) > 0 else None
)
Expand All @@ -828,16 +822,14 @@ async def _achat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)

# create initial set of tasks
plan_id = self.create_plan(message)

results = []
dispatch_event(AgentChatWithStepStartEvent(user_msg=message))
dispatcher.event(AgentChatWithStepStartEvent(user_msg=message))
while True:
# EXIT CONDITION: check if all sub-tasks are completed
next_task_ids = self.get_next_tasks(plan_id)
Expand All @@ -862,7 +854,7 @@ async def _achat(
# refine the plan
await self.arefine_plan(message, plan_id)

dispatch_event(
dispatcher.event(
AgentChatWithStepEndEvent(
response=results[-1] if len(results) > 0 else None
)
Expand Down
3 changes: 1 addition & 2 deletions llama-index-core/llama_index/core/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ async def run_jobs(
List[Any]:
List of results.
"""
parent_span_id = dispatcher.current_span_id
semaphore = asyncio.Semaphore(workers)

@dispatcher.async_span_with_parent_id(parent_id=parent_span_id)
@dispatcher.span
async def worker(job: Coroutine) -> Any:
async with semaphore:
return await job
Expand Down
16 changes: 8 additions & 8 deletions llama-index-core/llama_index/core/base/base_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,26 @@ def _update_prompts(self, prompts: PromptDictType) -> None:

@dispatcher.span
def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(QueryStartEvent(query=str_or_query_bundle))
dispatcher.event(QueryStartEvent(query=str_or_query_bundle))
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
query_result = self._query(str_or_query_bundle)
dispatch_event(QueryEndEvent(query=str_or_query_bundle, response=query_result))
dispatcher.event(
QueryEndEvent(query=str_or_query_bundle, response=query_result)
)
return query_result

@dispatcher.span
async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(QueryStartEvent(query=str_or_query_bundle))
dispatcher.event(QueryStartEvent(query=str_or_query_bundle))
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
query_result = await self._aquery(str_or_query_bundle)
dispatch_event(QueryEndEvent(query=str_or_query_bundle, response=query_result))
dispatcher.event(
QueryEndEvent(query=str_or_query_bundle, response=query_result)
)
return query_result

def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
Expand Down
11 changes: 4 additions & 7 deletions llama-index-core/llama_index/core/base/base_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,8 @@ def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
a QueryBundle object.
"""
dispatch_event = dispatcher.get_dispatch_event()

self._check_callback_manager()
dispatch_event(
dispatcher.event(
RetrievalStartEvent(
str_or_query_bundle=str_or_query_bundle,
)
Expand All @@ -246,7 +244,7 @@ def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
dispatch_event(
dispatcher.event(
RetrievalEndEvent(
str_or_query_bundle=str_or_query_bundle,
nodes=nodes,
Expand All @@ -257,9 +255,8 @@ def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
@dispatcher.span
async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
self._check_callback_manager()
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
dispatcher.event(
RetrievalStartEvent(
str_or_query_bundle=str_or_query_bundle,
)
Expand All @@ -280,7 +277,7 @@ async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
dispatch_event(
dispatcher.event(
RetrievalEndEvent(
str_or_query_bundle=str_or_query_bundle,
nodes=nodes,
Expand Down
Loading

0 comments on commit 1127908

Please sign in to comment.