diff --git a/src/superannotate/lib/core/usecases/__init__.py b/src/superannotate/lib/core/usecases/__init__.py index 7d34f674e..d8a4f0264 100644 --- a/src/superannotate/lib/core/usecases/__init__.py +++ b/src/superannotate/lib/core/usecases/__init__.py @@ -7,3 +7,6 @@ from lib.core.usecases.items import * # noqa: F403 F401 from lib.core.usecases.models import * # noqa: F403 F401 from lib.core.usecases.projects import * # noqa: F403 F401 + +import nest_asyncio +nest_asyncio.apply() \ No newline at end of file diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py index 9280a90f9..a59e4fe58 100644 --- a/src/superannotate/lib/core/usecases/annotations.py +++ b/src/superannotate/lib/core/usecases/annotations.py @@ -26,7 +26,6 @@ import boto3 import jsonschema.validators import lib.core as constants -import nest_asyncio from jsonschema import Draft7Validator from jsonschema import ValidationError from lib.core.conditions import Condition @@ -391,7 +390,6 @@ def execute(self): len(items_to_upload), description="Uploading Annotations" ) try: - nest_asyncio.apply() asyncio.run(self.run_workers(items_to_upload)) except Exception: logger.debug(traceback.format_exc()) @@ -737,7 +735,6 @@ def execute(self): except KeyError: missing_annotations.append(name) try: - nest_asyncio.apply() asyncio.run(self.run_workers(items_to_upload)) except Exception as e: logger.debug(e) @@ -935,7 +932,6 @@ def execute(self): json.dump(annotation_json, annotation_file) size = annotation_file.tell() annotation_file.seek(0) - nest_asyncio.apply() if size > BIG_FILE_THRESHOLD: uploaded = asyncio.run( self._service_provider.annotations.upload_big_annotation( @@ -1550,7 +1546,6 @@ def execute(self): large_items = list(filter(lambda item: item.id in large_item_ids, items)) small_items = list(filter(lambda item: item.id in small_items_ids, items)) try: - nest_asyncio.apply() annotations = asyncio.run(self.run_workers(large_items, small_items)) except Exception as e: logger.error(e) @@ -1735,7 +1730,6 @@ def execute(self): ).data if not folders: folders.append(self._folder) - nest_asyncio.apply() for folder in folders: if self._item_names: items = get_or_raise( diff --git a/tests/unit/test_async_functions.py b/tests/unit/test_async_functions.py index 8d9ced027..af4006291 100644 --- a/tests/unit/test_async_functions.py +++ b/tests/unit/test_async_functions.py @@ -7,6 +7,25 @@ sa = SAClient() +class DummyIterator: + def __init__(self, delay, to): + self.delay = delay + self.i = 0 + self.to = to + + def __aiter__(self): + return self + + async def __anext__(self): + i = self.i + if i >= self.to: + raise StopAsyncIteration + self.i += 1 + if i: + await asyncio.sleep(self.delay) + return i + + class TestAsyncFunctions(TestCase): PROJECT_NAME = "TestAsync" PROJECT_DESCRIPTION = "Desc" @@ -26,31 +45,35 @@ def setUpClass(cls): def tearDownClass(cls): sa.delete_project(cls.PROJECT_NAME) + @staticmethod + async def nested(): + annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME) + assert len(annotations) == 4 + def test_get_annotations_in_running_event_loop(self): async def _test(): annotations = sa.get_annotations(self.PROJECT_NAME) assert len(annotations) == 4 asyncio.run(_test()) - def test_multiple_get_annotations_in_running_event_loop(self): - # TODO add handling of nested loop - async def nested(): - sa.attach_items(self.PROJECT_NAME, self.ATTACH_PAYLOAD) - annotations = sa.get_annotations(self.PROJECT_NAME) - assert len(annotations) == 4 - async def create_task_test(): - import nest_asyncio - nest_asyncio.apply() - task1 = asyncio.create_task(nested()) - task2 = asyncio.create_task(nested()) + def test_create_task_get_annotations_in_running_event_loop(self): + async def _test(): + task1 = asyncio.create_task(self.nested()) + task2 = asyncio.create_task(self.nested()) await task1 await task2 - asyncio.run(create_task_test()) + asyncio.run(_test()) + + def test_gather_get_annotations_in_running_event_loop(self): + async def gather_test(): + await asyncio.gather(self.nested(), self.nested()) + asyncio.run(gather_test()) + def test_gather_async_for(self): async def gather_test(): - import nest_asyncio - nest_asyncio.apply() - await asyncio.gather(nested(), nested()) + async for _ in DummyIterator(delay=0.01, to=2): + annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME) + assert len(annotations) == 4 asyncio.run(gather_test()) def test_upload_annotations_in_running_event_loop(self):