diff --git a/pytest.ini b/pytest.ini index 86c2d4c63..260f700fe 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,4 @@ minversion = 3.7 log_cli=true python_files = test_*.py -addopts = -n auto --dist=loadscope \ No newline at end of file +;addopts = -n auto --dist=loadscope \ No newline at end of file diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index a10fb8b8d..04b84f412 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -20,6 +20,7 @@ from superannotate.lib.core import PACKAGE_VERSION_UPGRADE # noqa from superannotate.logger import get_default_logger # noqa from superannotate.version import __version__ # noqa +import superannotate.lib.core.enums as enums # noqa SESSIONS = {} @@ -27,6 +28,7 @@ "__version__", "SAClient", # Utils + "enums", "AppException", # analytics "class_distribution", diff --git a/src/superannotate/lib/app/interface/base_interface.py b/src/superannotate/lib/app/interface/base_interface.py index 06ff5de47..e7c38e87a 100644 --- a/src/superannotate/lib/app/interface/base_interface.py +++ b/src/superannotate/lib/app/interface/base_interface.py @@ -6,6 +6,7 @@ from types import FunctionType from typing import Iterable from typing import Sized +from typing import Tuple import lib.core as constants from lib.app.helpers import extract_project_folder @@ -21,41 +22,44 @@ class BaseInterfaceFacade: REGISTRY = [] - def __init__( - self, - token: str = None, - config_path: str = constants.CONFIG_PATH, - ): - env_token = os.environ.get("SA_TOKEN") - host = os.environ.get("SA_URL", constants.BACKEND_URL) + def __init__(self, token: str = None, config_path: str = None): version = os.environ.get("SA_VERSION", "v1") - ssl_verify = bool(os.environ.get("SA_SSL", True)) + _token, _config_path = None, None + _host = os.environ.get("SA_URL", constants.BACKEND_URL) + _ssl_verify = bool(os.environ.get("SA_SSL", True)) if token: - token = Controller.validate_token(token=token) - elif env_token: - host = os.environ.get("SA_URL", constants.BACKEND_URL) - token = Controller.validate_token(env_token) + _token = Controller.validate_token(token=token) + elif config_path: + _token, _host, _ssl_verify = self._retrieve_configs(config_path) else: - config_path = os.path.expanduser(str(config_path)) - if not Path(config_path).is_file() or not os.access(config_path, os.R_OK): - raise AppException( - f"SuperAnnotate config file {str(config_path)} not found." - f" Please provide correct config file location to sa.init() or use " - f"CLI's superannotate init to generate default location config file." + _token = os.environ.get("SA_TOKEN") + if not _token: + _toke, _host, _ssl_verify = self._retrieve_configs( + constants.CONFIG_PATH ) - config_repo = ConfigRepository(config_path) - main_endpoint = config_repo.get_one("main_endpoint").value - if not main_endpoint: - main_endpoint = constants.BACKEND_URL - token, host, ssl_verify = ( - Controller.validate_token(config_repo.get_one("token").value), - main_endpoint, - config_repo.get_one("ssl_verify").value, + self._token, self._host = _host, _token + self.controller = Controller(_token, _host, _ssl_verify, version) + + def __new__(cls, *args, **kwargs): + obj = super().__new__(cls, *args, **kwargs) + cls.REGISTRY.append(obj) + return obj + + @staticmethod + def _retrieve_configs(path) -> Tuple[str, str, str]: + config_path = os.path.expanduser(str(path)) + if not Path(config_path).is_file() or not os.access(config_path, os.R_OK): + raise AppException( + f"SuperAnnotate config file {str(config_path)} not found." + f" Please provide correct config file location to sa.init() or use " + f"CLI's superannotate init to generate default location config file." ) - self._host = host - self._token = token - self.controller = Controller(token, host, ssl_verify, version) - BaseInterfaceFacade.REGISTRY.append(self) + config_repo = ConfigRepository(config_path) + return ( + Controller.validate_token(config_repo.get_one("token").value), + config_repo.get_one("main_endpoint").value, + config_repo.get_one("ssl_verify").value, + ) @property def host(self): diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index fbc1d0e03..c2767823b 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -13,7 +13,7 @@ from typing import Union import boto3 -import lib.core as constances +import lib.core as constants from lib.app.annotation_helpers import add_annotation_bbox_to_json from lib.app.annotation_helpers import add_annotation_comment_to_json from lib.app.annotation_helpers import add_annotation_point_to_json @@ -62,6 +62,13 @@ class SAClient(BaseInterfaceFacade, metaclass=TrackableMeta): + def __init__( + self, + token: str = None, + config_path: str = constants.CONFIG_PATH, + ): + super().__init__(token, config_path) + def get_team_metadata(self): """Returns team metadata @@ -415,11 +422,11 @@ def copy_image( ).data if destination_project_metadata["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ] or source_project_metadata["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException( LIMITED_FUNCTIONS[source_project_metadata["project"].type] @@ -738,8 +745,8 @@ def assign_images( project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) @@ -864,12 +871,12 @@ def upload_images_from_folder_to_project( folder_path: Union[NotEmptyStr, Path], extensions: Optional[ Union[List[NotEmptyStr], Tuple[NotEmptyStr]] - ] = constances.DEFAULT_IMAGE_EXTENSIONS, + ] = constants.DEFAULT_IMAGE_EXTENSIONS, annotation_status="NotStarted", from_s3_bucket=None, exclude_file_patterns: Optional[ Iterable[NotEmptyStr] - ] = constances.DEFAULT_FILE_EXCLUDE_PATTERNS, + ] = constants.DEFAULT_FILE_EXCLUDE_PATTERNS, recursive_subfolders: Optional[StrictBool] = False, image_quality_in_editor: Optional[str] = None, ): @@ -926,7 +933,7 @@ def upload_images_from_folder_to_project( if exclude_file_patterns: exclude_file_patterns = list(exclude_file_patterns) + list( - constances.DEFAULT_FILE_EXCLUDE_PATTERNS + constants.DEFAULT_FILE_EXCLUDE_PATTERNS ) exclude_file_patterns = list(set(exclude_file_patterns)) @@ -1082,12 +1089,12 @@ def prepare_export( folders = folder_names if not annotation_statuses: annotation_statuses = [ - constances.AnnotationStatus.NOT_STARTED.name, - constances.AnnotationStatus.IN_PROGRESS.name, - constances.AnnotationStatus.QUALITY_CHECK.name, - constances.AnnotationStatus.RETURNED.name, - constances.AnnotationStatus.COMPLETED.name, - constances.AnnotationStatus.SKIPPED.name, + constants.AnnotationStatus.NOT_STARTED.name, + constants.AnnotationStatus.IN_PROGRESS.name, + constants.AnnotationStatus.QUALITY_CHECK.name, + constants.AnnotationStatus.RETURNED.name, + constants.AnnotationStatus.COMPLETED.name, + constants.AnnotationStatus.SKIPPED.name, ] response = self.controller.prepare_export( project_name=project_name, @@ -1106,7 +1113,7 @@ def upload_videos_from_folder_to_project( folder_path: Union[NotEmptyStr, Path], extensions: Optional[ Union[Tuple[NotEmptyStr], List[NotEmptyStr]] - ] = constances.DEFAULT_VIDEO_EXTENSIONS, + ] = constants.DEFAULT_VIDEO_EXTENSIONS, exclude_file_patterns: Optional[List[NotEmptyStr]] = (), recursive_subfolders: Optional[StrictBool] = False, target_fps: Optional[int] = None, @@ -1593,8 +1600,8 @@ def upload_preannotations_from_folder_to_project( project_folder_name = project_name + (f"/{folder_name}" if folder_name else "") project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) if recursive_subfolders: @@ -1649,8 +1656,8 @@ def upload_image_annotations( project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) @@ -1677,7 +1684,7 @@ def upload_image_annotations( mask=mask, verbose=verbose, ) - if response.errors and not response.errors == constances.INVALID_JSON_MESSAGE: + if response.errors and not response.errors == constants.INVALID_JSON_MESSAGE: raise AppException(response.errors) def download_model(self, model: MLModel, output_dir: Union[str, Path]): @@ -1735,8 +1742,8 @@ def benchmark( project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) @@ -1887,8 +1894,8 @@ def add_annotation_bbox_to_image( project_name, folder_name = extract_project_folder(project) project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) response = self.controller.get_annotations( @@ -1945,8 +1952,8 @@ def add_annotation_point_to_image( project_name, folder_name = extract_project_folder(project) project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) response = self.controller.get_annotations( @@ -2000,8 +2007,8 @@ def add_annotation_comment_to_image( project_name, folder_name = extract_project_folder(project) project = self.controller.get_project_metadata(project_name).data if project["project"].type in [ - constances.ProjectType.VIDEO.value, - constances.ProjectType.DOCUMENT.value, + constants.ProjectType.VIDEO.value, + constants.ProjectType.DOCUMENT.value, ]: raise AppException(LIMITED_FUNCTIONS[project["project"].type]) response = self.controller.get_annotations( @@ -2187,8 +2194,8 @@ def aggregate_annotations_as_df( :rtype: pandas DataFrame """ if project_type in ( - constances.ProjectType.VECTOR.name, - constances.ProjectType.PIXEL.name, + constants.ProjectType.VECTOR.name, + constants.ProjectType.PIXEL.name, ): from superannotate.lib.app.analytics.common import ( aggregate_image_annotations_as_df, @@ -2202,8 +2209,8 @@ def aggregate_annotations_as_df( folder_names=folder_names, ) elif project_type in ( - constances.ProjectType.VIDEO.name, - constances.ProjectType.DOCUMENT.name, + constants.ProjectType.VIDEO.name, + constants.ProjectType.DOCUMENT.name, ): from superannotate.lib.app.analytics.aggregators import DataAggregator @@ -2214,21 +2221,21 @@ def aggregate_annotations_as_df( ).aggregate_annotations_as_df() def delete_annotations( - self, project: NotEmptyStr, image_names: Optional[List[NotEmptyStr]] = None + self, project: NotEmptyStr, item_names: Optional[List[NotEmptyStr]] = None ): """ - Delete image annotations from a given list of images. + Delete item annotations from a given list of items. :param project: project name or folder path (e.g., "project1/folder1") :type project: str - :param image_names: image names. If None, all image annotations from a given project/folder will be deleted. - :type image_names: list of strs + :param item_names: image names. If None, all image annotations from a given project/folder will be deleted. + :type item_names: list of strs """ project_name, folder_name = extract_project_folder(project) response = self.controller.delete_annotations( - project_name=project_name, folder_name=folder_name, item_names=image_names + project_name=project_name, folder_name=folder_name, item_names=item_names ) if response.errors: raise AppException(response.errors) diff --git a/src/superannotate/lib/app/interface/types.py b/src/superannotate/lib/app/interface/types.py index 5103468cd..780c88a83 100644 --- a/src/superannotate/lib/app/interface/types.py +++ b/src/superannotate/lib/app/interface/types.py @@ -14,6 +14,7 @@ from pydantic import BaseModel from pydantic import conlist from pydantic import constr +from pydantic import errors from pydantic import Extra from pydantic import Field from pydantic import parse_obj_as @@ -21,11 +22,23 @@ from pydantic import StrictStr from pydantic import validate_arguments as pydantic_validate_arguments from pydantic import ValidationError +from pydantic.errors import PydanticTypeError from pydantic.errors import StrRegexError NotEmptyStr = constr(strict=True, min_length=1) +class EnumMemberError(PydanticTypeError): + code = "enum" + + def __str__(self) -> str: + permitted = ", ".join(str(v.name) for v in self.enum_values) # type: ignore + return f"Available values are: {permitted}" + + +errors.EnumMemberError = EnumMemberError + + class EmailStr(StrictStr): @classmethod def validate(cls, value: Union[str]) -> Union[str]: diff --git a/src/superannotate/lib/core/enums.py b/src/superannotate/lib/core/enums.py index 62a5bcc58..edefb6027 100644 --- a/src/superannotate/lib/core/enums.py +++ b/src/superannotate/lib/core/enums.py @@ -8,8 +8,13 @@ def __new__(cls, title, value): obj._value_ = value obj.__doc__ = title obj._type = "titled_enum" + cls._value2member_map_[title] = obj return obj + @classmethod + def choices(cls): + return tuple(cls._value2member_map_.keys()) + @DynamicClassAttribute def name(self) -> str: return self.__doc__ @@ -41,6 +46,9 @@ def titles(cls): def equals(self, other: Enum): return self.__doc__.lower() == other.__doc__.lower() + def __eq__(self, other): + return super().__eq__(other) + class AnnotationTypes(str, Enum): BBOX = "bbox" diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 7a9868912..7a70f2fe1 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -337,7 +337,10 @@ async def download_annotations( postfix: str, items: List[str] = None, callback: Callable = None, - ) -> List[dict]: + ) -> int: + """ + Returns the number of items downloaded + """ raise NotImplementedError def upload_priority_scores( diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py index b64bd9bf8..2a12bdb56 100644 --- a/src/superannotate/lib/core/usecases/annotations.py +++ b/src/superannotate/lib/core/usecases/annotations.py @@ -809,8 +809,9 @@ def get_items_count(path: str): def coroutine_wrapper(coroutine): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(coroutine) + count = loop.run_until_complete(coroutine) loop.close() + return count def execute(self): if self.is_valid(): @@ -842,7 +843,7 @@ def execute(self): if not folders: loop = asyncio.new_event_loop() - loop.run_until_complete( + count = loop.run_until_complete( self._backend_client.download_annotations( team_id=self._project.team_id, project_id=self._project.id, @@ -870,12 +871,12 @@ def execute(self): callback=self._callback, ) ) - _ = [_ for _ in executor.map(self.coroutine_wrapper, coroutines)] + count = sum( + [i for i in executor.map(self.coroutine_wrapper, coroutines)] + ) self.reporter.stop_spinner() - self.reporter.log_info( - f"Downloaded annotations for {self.get_items_count(export_path)} items." - ) + self.reporter.log_info(f"Downloaded annotations for {count} items.") self.download_annotation_classes(export_path) self._response.data = os.path.abspath(export_path) return self._response diff --git a/src/superannotate/lib/infrastructure/services.py b/src/superannotate/lib/infrastructure/services.py index 95e3c896e..8297f9ada 100644 --- a/src/superannotate/lib/infrastructure/services.py +++ b/src/superannotate/lib/infrastructure/services.py @@ -1086,7 +1086,7 @@ async def download_annotations( postfix: str, items: List[str] = None, callback: Callable = None, - ) -> List[dict]: + ) -> int: import aiohttp async with aiohttp.ClientSession( diff --git a/src/superannotate/lib/infrastructure/stream_data_handler.py b/src/superannotate/lib/infrastructure/stream_data_handler.py index ad12f2175..675701501 100644 --- a/src/superannotate/lib/infrastructure/stream_data_handler.py +++ b/src/superannotate/lib/infrastructure/stream_data_handler.py @@ -109,7 +109,11 @@ async def download_data( method: str = "post", params=None, chunk_size: int = 100, - ): + ) -> int: + """ + Returns the number of items downloaded + """ + items_downloaded: int = 0 if chunk_size and data: for i in range(0, len(data), chunk_size): data_to_process = data[i : i + chunk_size] @@ -126,6 +130,7 @@ async def download_data( annotation, self._callback, ) + items_downloaded += 1 else: async for annotation in self.fetch( method, session, url, self._process_data(data), params=params @@ -133,3 +138,5 @@ async def download_data( self._store_annotation( download_path, postfix, annotation, self._callback ) + items_downloaded += 1 + return items_downloaded diff --git a/tests/integration/annotations/test_download_annotations.py b/tests/integration/annotations/test_download_annotations.py index 1083d2c76..b32709515 100644 --- a/tests/integration/annotations/test_download_annotations.py +++ b/tests/integration/annotations/test_download_annotations.py @@ -60,11 +60,11 @@ def test_download_annotations_from_folders(self): f"{self.PROJECT_NAME}{'/' + folder if folder else ''}", self.folder_path ) with tempfile.TemporaryDirectory() as temp_dir: - annotations_path = sa.download_annotations(f"{self.PROJECT_NAME}", temp_dir) - self.assertEqual(len(os.listdir(annotations_path)), 5) + annotations_path = sa.download_annotations(f"{self.PROJECT_NAME}", temp_dir, recursive=True) + self.assertEqual(len(os.listdir(annotations_path)), 7) @pytest.mark.flaky(reruns=3) - def test_download_annotations_from_folders(self): + def test_download_empty_annotations_from_folders(self): sa.create_folder(self.PROJECT_NAME, self.FOLDER_NAME) sa.create_folder(self.PROJECT_NAME, self.FOLDER_NAME_2) sa.create_annotation_classes_from_classes_json( diff --git a/tests/integration/test_single_annotation_download.py b/tests/integration/test_single_annotation_download.py index fb00b50c5..6e2f4a3ab 100644 --- a/tests/integration/test_single_annotation_download.py +++ b/tests/integration/test_single_annotation_download.py @@ -1,15 +1,16 @@ - import filecmp import json import os import tempfile from os.path import dirname + import pytest from src.superannotate import SAClient -sa = SAClient() from tests.integration.base import BaseTestCase +sa = SAClient() + class TestSingleAnnotationDownloadUpload(BaseTestCase): PROJECT_NAME = "test_single_annotation" @@ -68,7 +69,7 @@ def test_annotation_download_upload_vector(self): ) ) # TODO: - #assert downloaded_json == uploaded_json + # assert downloaded_json == uploaded_json class TestSingleAnnotationDownloadUploadPixel(BaseTestCase): diff --git a/tests/unit/test_enum_arguments_handeling.py b/tests/unit/test_enum_arguments_handeling.py new file mode 100644 index 000000000..7ff71fbb8 --- /dev/null +++ b/tests/unit/test_enum_arguments_handeling.py @@ -0,0 +1,20 @@ +# from typing import Literal +from pydantic.typing import Literal + +from superannotate import enums +from superannotate import SAClient +from superannotate.lib.app.interface.types import validate_arguments + + + +@validate_arguments +def foo(status: enums.ProjectStatus): + return status + + +def test_enum_arg(): + SAClient() + assert foo(1) == 1 + assert foo("NotStarted") == 1 + assert foo(enums.ProjectStatus.NotStarted.name) == 1 + assert foo(enums.ProjectStatus.NotStarted.value) == 1