Add span_id attribute to Events (instrumentation) (#12417)
* add span_id to Event

* remove raise err in NullHandler

* wip

* modify root dispatcher event enclosing span

* remove *args as we have bound_args now

* add LLMChatInProgressEvent

* add LLMStructuredPredict Events

* store span_id before await executions

* add SpanDropEvent with err_str payload

* add event to _achat; flush current_span_id when open_spans is empty

* llm callbacks use root span_id

* add unit tests

* remove print statements

* provide context manager returning a dispatch event partial with correct span id

* move to context manager usage

* fix invocation of cm

* define and use get_dispatch_event method

* remove aim tests
nerdai committed Apr 2, 2024
1 parent 88750ab commit 95be930
Showing 22 changed files with 617 additions and 132 deletions.
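
The pattern applied throughout this commit: instead of calling dispatcher.event(...) at each call site, which resolves the active span only at dispatch time, each instrumented method first captures a dispatch partial bound to the current span id via dispatcher.get_dispatch_event(). Events fired after an await then still carry the span that opened them ("store span_id before await executions" above). A minimal self-contained sketch of the idea, using toy Dispatcher and Event classes rather than the actual llama_index instrumentation API:

import functools
from typing import Callable, Optional

class Event:
    """Toy event carrying the span_id attribute this commit adds."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.span_id: Optional[str] = None

class Dispatcher:
    """Toy dispatcher tracking whichever span is currently open."""

    def __init__(self) -> None:
        self.current_span_id: Optional[str] = None

    def event(self, ev: Event, span_id: Optional[str] = None) -> None:
        # Stamp the event with the bound span id if one was captured,
        # otherwise fall back to whatever span is active right now.
        ev.span_id = span_id or self.current_span_id
        print(f"{ev.name} -> span {ev.span_id}")

    def get_dispatch_event(self) -> Callable[[Event], None]:
        # Bind the *current* span id eagerly; awaits that run between
        # binding and dispatch can no longer misattribute the event.
        return functools.partial(self.event, span_id=self.current_span_id)

dispatcher = Dispatcher()
dispatcher.current_span_id = "span-1"
dispatch_event = dispatcher.get_dispatch_event()
dispatcher.current_span_id = "span-2"   # e.g. another coroutine opened a span
dispatch_event(Event("QueryEndEvent"))  # still attributed to span-1
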
23 changes: 18 additions & 5 deletions llama-index-core/llama_index/core/agent/runner/base.py
@@ -365,7 +365,9 @@ def _run_step(
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
dispatcher.event(AgentRunStepStartEvent())
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(AgentRunStepStartEvent())
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
@@ -392,7 +394,7 @@ def _run_step(
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)

dispatcher.event(AgentRunStepEndEvent())
dispatch_event(AgentRunStepEndEvent())
return cur_step_output

@dispatcher.span
@@ -405,6 +407,9 @@ async def _arun_step(
**kwargs: Any,
) -> TaskStepOutput:
"""Execute step."""
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(AgentRunStepStartEvent())
task = self.state.get_task(task_id)
step_queue = self.state.get_step_queue(task_id)
step = step or step_queue.popleft()
@@ -430,6 +435,7 @@ async def _arun_step(
completed_steps = self.state.get_completed_steps(task_id)
completed_steps.append(cur_step_output)

dispatch_event(AgentRunStepEndEvent())
return cur_step_output

@dispatcher.span
@@ -528,12 +534,14 @@ def _chat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)

result_output = None
dispatcher.event(AgentChatWithStepStartEvent())
dispatch_event(AgentChatWithStepStartEvent())
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = self._run_step(
@@ -551,7 +559,7 @@ def _chat(
task.task_id,
result_output,
)
dispatcher.event(AgentChatWithStepEndEvent())
dispatch_event(AgentChatWithStepEndEvent())
return result

@dispatcher.span
@@ -563,11 +571,14 @@ async def _achat(
mode: ChatResponseMode = ChatResponseMode.WAIT,
) -> AGENT_CHAT_RESPONSE_TYPE:
"""Chat with step executor."""
dispatch_event = dispatcher.get_dispatch_event()

if chat_history is not None:
self.memory.set(chat_history)
task = self.create_task(message)

result_output = None
dispatch_event(AgentChatWithStepStartEvent())
while True:
# pass step queue in as argument, assume step executor is stateless
cur_step_output = await self._arun_step(
@@ -581,10 +592,12 @@ async def _achat(
# ensure tool_choice does not cause endless loops
tool_choice = "auto"

return self.finalize_response(
result = self.finalize_response(
task.task_id,
result_output,
)
dispatch_event(AgentChatWithStepEndEvent())
return result

@dispatcher.span
@trace_method("chat")
12 changes: 8 additions & 4 deletions llama-index-core/llama_index/core/base/base_query_engine.py
@@ -44,22 +44,26 @@ def _update_prompts(self, prompts: PromptDictType) -> None:

@dispatcher.span
def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
dispatcher.event(QueryStartEvent())
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(QueryStartEvent())
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
query_result = self._query(str_or_query_bundle)
dispatcher.event(QueryEndEvent())
dispatch_event(QueryEndEvent())
return query_result

@dispatcher.span
async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE:
dispatcher.event(QueryStartEvent())
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(QueryStartEvent())
with self.callback_manager.as_trace("query"):
if isinstance(str_or_query_bundle, str):
str_or_query_bundle = QueryBundle(str_or_query_bundle)
query_result = await self._aquery(str_or_query_bundle)
dispatcher.event(QueryEndEvent())
dispatch_event(QueryEndEvent())
return query_result

def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
30 changes: 24 additions & 6 deletions llama-index-core/llama_index/core/base/base_retriever.py
@@ -224,8 +224,14 @@ def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
a QueryBundle object.
"""
dispatch_event = dispatcher.get_dispatch_event()

self._check_callback_manager()
dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
dispatch_event(
RetrievalStartEvent(
str_or_query_bundle=str_or_query_bundle,
)
)
if isinstance(str_or_query_bundle, str):
query_bundle = QueryBundle(str_or_query_bundle)
else:
@@ -240,15 +246,24 @@ def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
dispatcher.event(
RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
dispatch_event(
RetrievalEndEvent(
str_or_query_bundle=str_or_query_bundle,
nodes=nodes,
)
)
return nodes

@dispatcher.span
async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
self._check_callback_manager()
dispatcher.event(RetrievalStartEvent(str_or_query_bundle=str_or_query_bundle))
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
RetrievalStartEvent(
str_or_query_bundle=str_or_query_bundle,
)
)
if isinstance(str_or_query_bundle, str):
query_bundle = QueryBundle(str_or_query_bundle)
else:
@@ -265,8 +280,11 @@ async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]
retrieve_event.on_end(
payload={EventPayload.NODES: nodes},
)
dispatcher.event(
RetrievalEndEvent(str_or_query_bundle=str_or_query_bundle, nodes=nodes)
dispatch_event(
RetrievalEndEvent(
str_or_query_bundle=str_or_query_bundle,
nodes=nodes,
)
)
return nodes

90 changes: 74 additions & 16 deletions llama-index-core/llama_index/core/base/embeddings/base.py
@@ -114,7 +114,13 @@ def get_query_embedding(self, query: str) -> Embedding:
other examples of predefined instructions can be found in
embeddings/huggingface_utils.py.
"""
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
@@ -126,15 +132,24 @@ def get_query_embedding(self, query: str) -> Embedding:
EventPayload.EMBEDDINGS: [query_embedding],
},
)
dispatcher.event(
EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding])
dispatch_event(
EmbeddingEndEvent(
chunks=[query],
embeddings=[query_embedding],
)
)
return query_embedding

@dispatcher.span
async def aget_query_embedding(self, query: str) -> Embedding:
"""Get query embedding."""
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
@@ -146,8 +161,11 @@ async def aget_query_embedding(self, query: str) -> Embedding:
EventPayload.EMBEDDINGS: [query_embedding],
},
)
dispatcher.event(
EmbeddingEndEvent(chunks=[query], embeddings=[query_embedding])
dispatch_event(
EmbeddingEndEvent(
chunks=[query],
embeddings=[query_embedding],
)
)
return query_embedding

@@ -220,7 +238,13 @@ def get_text_embedding(self, text: str) -> Embedding:
document for retrieval: ". If you're curious, other examples of
predefined instructions can be found in embeddings/huggingface_utils.py.
"""
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
@@ -232,13 +256,24 @@ def get_text_embedding(self, text: str) -> Embedding:
EventPayload.EMBEDDINGS: [text_embedding],
}
)
dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding]))
dispatch_event(
EmbeddingEndEvent(
chunks=[text],
embeddings=[text_embedding],
)
)
return text_embedding

@dispatcher.span
async def aget_text_embedding(self, text: str) -> Embedding:
"""Async get text embedding."""
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event = dispatcher.get_dispatch_event()

dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
with self.callback_manager.event(
CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
) as event:
@@ -250,7 +285,12 @@ async def aget_text_embedding(self, text: str) -> Embedding:
EventPayload.EMBEDDINGS: [text_embedding],
}
)
dispatcher.event(EmbeddingEndEvent(chunks=[text], embeddings=[text_embedding]))
dispatch_event(
EmbeddingEndEvent(
chunks=[text],
embeddings=[text_embedding],
)
)
return text_embedding

@dispatcher.span
@@ -261,6 +301,8 @@ def get_text_embedding_batch(
**kwargs: Any,
) -> List[Embedding]:
"""Get a list of text embeddings, with batching."""
dispatch_event = dispatcher.get_dispatch_event()

cur_batch: List[str] = []
result_embeddings: List[Embedding] = []

@@ -272,7 +314,11 @@ def get_text_embedding_batch(
cur_batch.append(text)
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
# flush
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
with self.callback_manager.event(
CBEventType.EMBEDDING,
payload={EventPayload.SERIALIZED: self.to_dict()},
@@ -285,8 +331,11 @@
EventPayload.EMBEDDINGS: embeddings,
},
)
dispatcher.event(
EmbeddingEndEvent(chunks=cur_batch, embeddings=embeddings)
dispatch_event(
EmbeddingEndEvent(
chunks=cur_batch,
embeddings=embeddings,
)
)
cur_batch = []

Expand All @@ -297,6 +346,8 @@ async def aget_text_embedding_batch(
self, texts: List[str], show_progress: bool = False
) -> List[Embedding]:
"""Asynchronously get a list of text embeddings, with batching."""
dispatch_event = dispatcher.get_dispatch_event()

cur_batch: List[str] = []
callback_payloads: List[Tuple[str, List[str]]] = []
result_embeddings: List[Embedding] = []
@@ -305,7 +356,11 @@ async def aget_text_embedding_batch(
cur_batch.append(text)
if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size:
# flush
dispatcher.event(EmbeddingStartEvent(model_dict=self.to_dict()))
dispatch_event(
EmbeddingStartEvent(
model_dict=self.to_dict(),
)
)
event_id = self.callback_manager.on_event_start(
CBEventType.EMBEDDING,
payload={EventPayload.SERIALIZED: self.to_dict()},
@@ -337,8 +392,11 @@ async def aget_text_embedding_batch(
for (event_id, text_batch), embeddings in zip(
callback_payloads, nested_embeddings
):
dispatcher.event(
EmbeddingEndEvent(chunks=text_batch, embeddings=embeddings)
dispatch_event(
EmbeddingEndEvent(
chunks=text_batch,
embeddings=embeddings,
)
)
self.callback_manager.on_event_end(
CBEventType.EMBEDDING,
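
With span_id stamped on every event, downstream consumers can correlate start/end pairs emitted by the same call even when several operations interleave. A hypothetical consumer, again with toy types rather than the real llama_index event-handler interface:

from collections import defaultdict
from typing import DefaultDict, List

class SpanGroupingHandler:
    """Toy handler that groups event names by the span that emitted them."""

    def __init__(self) -> None:
        self.by_span: DefaultDict[str, List[str]] = defaultdict(list)

    def handle(self, name: str, span_id: str) -> None:
        self.by_span[span_id].append(name)

handler = SpanGroupingHandler()
handler.handle("AgentChatWithStepStartEvent", "span-1")
handler.handle("RetrievalStartEvent", "span-2")
handler.handle("RetrievalEndEvent", "span-2")
handler.handle("AgentChatWithStepEndEvent", "span-1")
print(dict(handler.by_span))
# {'span-1': ['AgentChatWithStepStartEvent', 'AgentChatWithStepEndEvent'],
#  'span-2': ['RetrievalStartEvent', 'RetrievalEndEvent']}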
