Commit ce95bd7

Fix asyncio.run() in async context. (#13309)

Falven committed May 7, 2024
1 parent 8cb9690 commit ce95bd7
Showing 20 changed files with 62 additions and 51 deletions.
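For context, the failure mode this commit works around: asyncio.run() always starts a new event loop, and it raises a RuntimeError when the calling thread already has a loop running, which is common in notebooks and inside other async frameworks. A minimal sketch of the problem (illustrative only, not part of the diff):

    # Sketch of the error this commit guards against (not from the repo).
    import asyncio

    async def inner() -> str:
        return "done"

    async def outer() -> None:
        try:
            # Calling asyncio.run() while this coroutine's loop is running fails:
            # RuntimeError: asyncio.run() cannot be called from a running event loop
            asyncio.run(inner())
        except RuntimeError as exc:
            print(f"caught: {exc}")

    asyncio.run(outer())  # the top-level call is fine: no loop is running yet

The commit replaces direct asyncio.run() calls throughout the codebase with a new asyncio_run() helper (added in async_utils.py below) that tolerates an already-running loop.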
2 changes: 1 addition & 1 deletion llama-index-core/llama_index/core/agent/react/step.py

@@ -728,7 +728,7 @@ async def _arun_step_stream(
             # wait until response writing is done
             agent_response._ensure_async_setup()

-            await agent_response._is_function_false_event.wait()
+            await agent_response.is_function_false_event.wait()

             return self._get_task_step_response(agent_response, step, is_done)
5 changes: 2 additions & 3 deletions llama-index-core/llama_index/core/agent/runner/base.py

@@ -1,4 +1,3 @@
-import asyncio
 import os
 from abc import abstractmethod
 from collections import deque
@@ -11,7 +10,7 @@
     TaskStep,
     TaskStepOutput,
 )
-from llama_index.core.async_utils import run_jobs
+from llama_index.core.async_utils import asyncio_run, run_jobs
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.callbacks import (
     CallbackManager,
@@ -799,7 +798,7 @@ def _chat(
             self.arun_task(sub_task_id, mode=mode, tool_choice=tool_choice)
             for sub_task_id in next_task_ids
         ]
-        results = asyncio.run(run_jobs(jobs, workers=len(jobs)))
+        results = asyncio_run(run_jobs(jobs, workers=len(jobs)))

         for sub_task_id in next_task_ids:
             self.mark_task_complete(plan_id, sub_task_id)
3 changes: 2 additions & 1 deletion llama-index-core/llama_index/core/agent/runner/parallel.py

@@ -11,6 +11,7 @@
     TaskStep,
     TaskStepOutput,
 )
+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.callbacks import (
     CallbackManager,
@@ -183,7 +184,7 @@ def run_steps_in_queue(
         Assume that all steps can be run in parallel.
         """
-        return asyncio.run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))
+        return asyncio_run(self.arun_steps_in_queue(task_id, mode=mode, **kwargs))

     async def arun_steps_in_queue(
         self,
14 changes: 13 additions & 1 deletion llama-index-core/llama_index/core/async_utils.py

@@ -20,6 +20,18 @@ def asyncio_module(show_progress: bool = False) -> Any:
     return module


+def asyncio_run(coro: Coroutine) -> Any:
+    try:
+        loop = asyncio.get_running_loop()
+        return (
+            asyncio.ensure_future(coro)
+            if loop.is_running()
+            else loop.run_until_complete(coro)
+        )
+    except RuntimeError:
+        return asyncio.run(coro)
+
+
 def run_async_tasks(
     tasks: List[Coroutine],
     show_progress: bool = False,
@@ -51,7 +63,7 @@ async def _tqdm_gather() -> List[Any]:
     async def _gather() -> List[Any]:
         return await asyncio.gather(*tasks_to_execute)

-    outputs: List[Any] = asyncio.run(_gather())
+    outputs: List[Any] = asyncio_run(_gather())
     return outputs
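A quick illustration of the new helper's behavior (illustrative usage, not part of the diff): with no loop running, asyncio_run() falls through to asyncio.run() and returns the coroutine's result; inside a running loop, it schedules the coroutine with asyncio.ensure_future() and returns the resulting Task, so async callers receive an awaitable instead of a crash.

    # Illustrative usage of the asyncio_run helper added above.
    from llama_index.core.async_utils import asyncio_run

    async def compute() -> int:
        return 41 + 1

    # No loop running here: asyncio_run delegates to asyncio.run()
    # and returns the coroutine's result directly.
    print(asyncio_run(compute()))  # 42

    async def caller() -> None:
        # A loop is already running here: asyncio_run returns a scheduled
        # Task rather than the final value, so await it to get the result.
        task = asyncio_run(compute())
        print(await task)  # 42

    asyncio_run(caller())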
3 changes: 2 additions & 1 deletion llama-index-core/llama_index/core/base/response/schema.py

@@ -4,6 +4,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.schema import NodeWithScore
 from llama_index.core.types import TokenGen, TokenAsyncGen
@@ -179,7 +180,7 @@ def __post_init__(self) -> None:

     def __str__(self) -> str:
         """Convert to string representation."""
-        return asyncio.run(self._async_str)
+        return asyncio_run(self._async_str)

     async def _async_str(self) -> str:
         """Convert to string representation."""
llama-index-core/llama_index/core/chat_engine/condense_plus_context.py

@@ -1,8 +1,8 @@
-import asyncio
 import logging
 from threading import Thread
 from typing import Any, List, Optional, Tuple

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.callbacks import CallbackManager, trace_method
 from llama_index.core.chat_engine.types import (
@@ -358,7 +358,7 @@ async def astream_chat(
             source_nodes=context_nodes,
         )
         thread = Thread(
-            target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
+            target=lambda x: asyncio_run(chat_response.awrite_response_to_history(x)),
             args=(self._memory,),
         )
         thread.start()
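The chat-engine changes here and in context.py and simple.py below all share one pattern: the streaming response is written back into chat memory on a freshly spawned Thread. A new thread has no running event loop, so asyncio_run() falls through to asyncio.run() and drives the coroutine to completion on that thread. A simplified sketch of the pattern (stand-in coroutine and memory, not the actual llama-index classes):

    # Simplified sketch of the background history-writing pattern
    # (stand-in names, not the real llama-index classes).
    import asyncio
    from threading import Thread
    from typing import List

    from llama_index.core.async_utils import asyncio_run

    async def awrite_response_to_history(memory: List[str]) -> None:
        # Stand-in for streaming the assistant response into chat memory.
        await asyncio.sleep(0)
        memory.append("assistant: hello")

    memory: List[str] = []
    # The worker thread starts with no event loop, so asyncio_run simply
    # delegates to asyncio.run() and blocks that thread until completion.
    thread = Thread(
        target=lambda m: asyncio_run(awrite_response_to_history(m)),
        args=(memory,),
    )
    thread.start()
    thread.join()
    print(memory)  # ['assistant: hello']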
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/chat_engine/context.py

@@ -1,7 +1,7 @@
-import asyncio
 from threading import Thread
 from typing import Any, List, Optional, Tuple

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.base.base_retriever import BaseRetriever
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 from llama_index.core.callbacks import CallbackManager, trace_method
@@ -292,7 +292,7 @@ async def astream_chat(
             source_nodes=nodes,
         )
         thread = Thread(
-            target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
+            target=lambda x: asyncio_run(chat_response.awrite_response_to_history(x)),
             args=(self._memory,),
         )
         thread.start()
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/chat_engine/simple.py

@@ -1,7 +1,7 @@
-import asyncio
 from threading import Thread
 from typing import Any, List, Optional, Type

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.callbacks import CallbackManager, trace_method
 from llama_index.core.chat_engine.types import (
@@ -167,7 +167,7 @@ async def astream_chat(
             achat_stream=await self._llm.astream_chat(all_messages)
         )
         thread = Thread(
-            target=lambda x: asyncio.run(chat_response.awrite_response_to_history(x)),
+            target=lambda x: asyncio_run(chat_response.awrite_response_to_history(x)),
             args=(self._memory,),
         )
         thread.start()
4 changes: 2 additions & 2 deletions llama-index-core/llama_index/core/evaluation/base.py

@@ -1,8 +1,8 @@
 """Evaluator."""
-import asyncio
 from abc import abstractmethod
 from typing import Any, Optional, Sequence

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.base.response.schema import Response
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.prompts.mixin import PromptMixin, PromptMixinType
@@ -59,7 +59,7 @@ def evaluate(
         Subclasses can override this method to provide custom evaluation logic and
         take in additional arguments.
         """
-        return asyncio.run(
+        return asyncio_run(
             self.aevaluate(
                 query=query,
                 response=response,
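This evaluator change is an instance of the sync-facade-over-async pattern that recurs through the rest of the commit: a blocking method delegates to its async counterpart via asyncio_run(). A minimal sketch (hypothetical Evaluator class, not the actual BaseEvaluator):

    # Minimal sketch of the sync-facade pattern (hypothetical Evaluator class,
    # not the actual llama_index BaseEvaluator).
    from llama_index.core.async_utils import asyncio_run

    class Evaluator:
        async def aevaluate(self, query: str) -> str:
            # Stand-in for real async evaluation work.
            return f"evaluated: {query}"

        def evaluate(self, query: str) -> str:
            # Delegate to the async implementation. With asyncio_run this no
            # longer crashes when a loop is already running, though in that
            # case it returns a scheduled Task instead of the final value.
            return asyncio_run(self.aevaluate(query))

    print(Evaluator().evaluate("Cat?"))  # evaluated: Cat?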
8 changes: 4 additions & 4 deletions llama-index-core/llama_index/core/evaluation/batch_runner.py

@@ -2,7 +2,7 @@
 from tenacity import retry, stop_after_attempt, wait_exponential
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

-from llama_index.core.async_utils import asyncio_module
+from llama_index.core.async_utils import asyncio_module, asyncio_run
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
 from llama_index.core.evaluation.base import BaseEvaluator, EvaluationResult
@@ -360,7 +360,7 @@ def evaluate_response_strs(
         Sync version of aevaluate_response_strs.
         """
-        return asyncio.run(
+        return asyncio_run(
             self.aevaluate_response_strs(
                 queries=queries,
                 response_strs=response_strs,
@@ -381,7 +381,7 @@ def evaluate_responses(
         Sync version of aevaluate_responses.
         """
-        return asyncio.run(
+        return asyncio_run(
             self.aevaluate_responses(
                 queries=queries,
                 responses=responses,
@@ -401,7 +401,7 @@ def evaluate_queries(
         Sync version of aevaluate_queries.
        """
-        return asyncio.run(
+        return asyncio_run(
             self.aevaluate_queries(
                 query_engine=query_engine,
                 queries=queries,
llama-index-core/llama_index/core/evaluation/dataset_generation.py

@@ -8,6 +8,7 @@
 from typing import Coroutine, Dict, List, Optional, Tuple

 from deprecated import deprecated
+from llama_index.core.async_utils import asyncio_run
 from llama_index.core import Document, ServiceContext, SummaryIndex
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.callbacks.base import CallbackManager
@@ -325,13 +326,13 @@ async def agenerate_dataset_from_nodes(

     def generate_questions_from_nodes(self, num: int | None = None) -> List[str]:
         """Generates questions for each document."""
-        return asyncio.run(self.agenerate_questions_from_nodes(num=num))
+        return asyncio_run(self.agenerate_questions_from_nodes(num=num))

     def generate_dataset_from_nodes(
         self, num: int | None = None
     ) -> QueryResponseDataset:
         """Generates questions for each document."""
-        return asyncio.run(self.agenerate_dataset_from_nodes(num=num))
+        return asyncio_run(self.agenerate_dataset_from_nodes(num=num))

     def _get_prompts(self) -> PromptDictType:
         """Get prompts."""
5 changes: 2 additions & 3 deletions llama-index-core/llama_index/core/evaluation/eval_utils.py

@@ -4,7 +4,6 @@
 """

-import asyncio
 import subprocess
 import tempfile
 from collections import defaultdict
@@ -16,7 +15,7 @@
 from llama_index_client import ProjectCreate
 from llama_index_client.types.eval_question_create import EvalQuestionCreate

-from llama_index.core.async_utils import asyncio_module
+from llama_index.core.async_utils import asyncio_module, asyncio_run
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.constants import DEFAULT_PROJECT_NAME
 from llama_index.core.evaluation.base import EvaluationResult
@@ -46,7 +45,7 @@ def get_responses(
     Sync version of aget_responses.
     """
-    return asyncio.run(aget_responses(*args, **kwargs))
+    return asyncio_run(aget_responses(*args, **kwargs))


 def get_results_df(
llama-index-core/llama_index/core/evaluation/retrieval/base.py

@@ -5,6 +5,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.evaluation.retrieval.metrics import resolve_metrics
 from llama_index.core.evaluation.retrieval.metrics_base import (
@@ -123,7 +124,7 @@ def evaluate(
             RetrievalEvalResult: Evaluation result
         """
-        return asyncio.run(
+        return asyncio_run(
             self.aevaluate(
                 query=query,
                 expected_ids=expected_ids,
6 changes: 3 additions & 3 deletions llama-index-core/llama_index/core/extractors/interface.py

@@ -1,9 +1,9 @@
 """Node parser interface."""
-import asyncio
 from abc import abstractmethod
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Sequence, cast

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.schema import (
     BaseNode,
@@ -92,7 +92,7 @@ def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
         nodes (Sequence[Document]): nodes to extract metadata from
         """
-        return asyncio.run(self.aextract(nodes))
+        return asyncio_run(self.aextract(nodes))

     async def aprocess_nodes(
         self,
@@ -139,7 +139,7 @@ def process_nodes(
         excluded_llm_metadata_keys: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> List[BaseNode]:
-        return asyncio.run(
+        return asyncio_run(
             self.aprocess_nodes(
                 nodes,
                 excluded_embed_metadata_keys=excluded_embed_metadata_keys,
7 changes: 3 additions & 4 deletions llama-index-core/llama_index/core/llama_dataset/generator.py

@@ -1,12 +1,11 @@
 """Dataset generation from documents."""
 from __future__ import annotations

-import asyncio
 import re
 from typing import List, Optional

 from llama_index.core import Document, ServiceContext, SummaryIndex
-from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
+from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs, asyncio_run
 from llama_index.core.base.response.schema import RESPONSE_TYPE
 from llama_index.core.ingestion import run_transformations
 from llama_index.core.llama_dataset import (
@@ -239,11 +238,11 @@ async def agenerate_dataset_from_nodes(self) -> LabelledRagDataset:

     def generate_questions_from_nodes(self) -> LabelledRagDataset:
         """Generates questions but not the reference answers."""
-        return asyncio.run(self.agenerate_questions_from_nodes())
+        return asyncio_run(self.agenerate_questions_from_nodes())

     def generate_dataset_from_nodes(self) -> LabelledRagDataset:
         """Generates questions for each document."""
-        return asyncio.run(self.agenerate_dataset_from_nodes())
+        return asyncio_run(self.agenerate_dataset_from_nodes())

     def _get_prompts(self) -> PromptDictType:
         """Get prompts."""
llama-index-core/llama_index/core/node_parser/relational/base_element.py

@@ -1,12 +1,11 @@
-import asyncio
 import uuid
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import pandas as pd
 from tqdm import tqdm

-from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
+from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs, asyncio_run
 from llama_index.core.base.response.schema import PydanticResponse
 from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError
 from llama_index.core.callbacks.base import CallbackManager
@@ -183,11 +182,10 @@ async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
             _get_table_output(table_context, self.summary_query_str)
             for table_context in table_context_list
         ]
-        summary_outputs = asyncio.run(
-            run_jobs(
-                summary_jobs, show_progress=self.show_progress, workers=self.num_workers
-            )
+        summary_co = run_jobs(
+            summary_jobs, show_progress=self.show_progress, workers=self.num_workers
         )
+        summary_outputs = asyncio_run(summary_co)
         for element, summary_output in zip(elements, summary_outputs):
             element.table_output = summary_output
4 changes: 2 additions & 2 deletions llama-index-core/tests/indices/query/test_compose_vector.py

@@ -1,9 +1,9 @@
 """Test recursive queries."""

-import asyncio
 from typing import Any, Dict, List

 import pytest
+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.core.data_structs.data_structs import IndexStruct
 from llama_index.core.indices.composability.graph import ComposableGraph
@@ -258,7 +258,7 @@ def test_recursive_query_vector_table_async(

     query_engine = graph.as_query_engine(custom_query_engines=custom_query_engines)
     task = query_engine.aquery("Cat?")
-    response = asyncio.run(task)
+    response = asyncio_run(task)
     assert str(response) == ("Cat?:Cat?:This is a test v2.")
llama-index-core/tests/indices/response/test_response_builder.py

@@ -1,8 +1,8 @@
 """Test response utils."""

-import asyncio
 from typing import List

+from llama_index.core.async_utils import asyncio_run
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
 from llama_index.core.indices.prompt_helper import PromptHelper
 from llama_index.core.prompts.base import PromptTemplate
@@ -266,7 +266,7 @@ def test_accumulate_response_aget(
         response_mode=ResponseMode.ACCUMULATE,
     )

-    response = asyncio.run(
+    response = asyncio_run(
         builder.aget_response(
             text_chunks=texts,
             query_str=query_str,
