diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3957f36ba..2efb16869 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,13 +7,13 @@ on: jobs: build-n-publish: name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.8" - name: Upgrade pip run: >- python -m @@ -36,7 +36,7 @@ jobs: . - name: Publish distribution 📦 to PyPI if: startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.pypi_password }} verbose: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 7502eec00..f29c383d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog All release highlights of this project will be documented in this file. +## 4.4.11 - April 2, 2023 +### Added +- `SAClient.set_project_status()` method. +- `SAClient.set_folder_status()` method. +### Updated +- `SAClient.create_annotation_class()` added OCR type attribute group support in the vector projects. +- `SAClient.create_annotation_classes_from_classes_json()` added OCR type attribute group support in the vector projects. 
## 4.4.10 - March 12, 2023 ### Updated - Configuration file creation flow diff --git a/docs/source/api_reference/api_client.rst b/docs/source/api_reference/api_client.rst index 38a59902d..fde8729c7 100644 --- a/docs/source/api_reference/api_client.rst +++ b/docs/source/api_reference/api_client.rst @@ -9,6 +9,7 @@ Contents :maxdepth: 8 api_project + api_folder api_item api_annotation api_annotation_class diff --git a/docs/source/api_reference/api_folder.rst b/docs/source/api_reference/api_folder.rst new file mode 100644 index 000000000..d4890066c --- /dev/null +++ b/docs/source/api_reference/api_folder.rst @@ -0,0 +1,11 @@ +======= +Folders +======= + +.. automethod:: superannotate.SAClient.search_folders +.. automethod:: superannotate.SAClient.assign_folder +.. automethod:: superannotate.SAClient.unassign_folder +.. automethod:: superannotate.SAClient.get_folder_by_id +.. automethod:: superannotate.SAClient.get_folder_metadata +.. automethod:: superannotate.SAClient.create_folder +.. automethod:: superannotate.SAClient.delete_folders diff --git a/docs/source/api_reference/api_image.rst b/docs/source/api_reference/api_image.rst index aca87a58e..12bcb0f25 100644 --- a/docs/source/api_reference/api_image.rst +++ b/docs/source/api_reference/api_image.rst @@ -1,6 +1,6 @@ -========== +====== Images -========== +====== .. _ref_search_images: diff --git a/docs/source/api_reference/api_project.rst b/docs/source/api_reference/api_project.rst index 626dab2b3..ca6e23344 100644 --- a/docs/source/api_reference/api_project.rst +++ b/docs/source/api_reference/api_project.rst @@ -1,25 +1,19 @@ -========== +======== Projects -========== +======== .. _ref_projects: .. _ref_search_projects: -.. automethod:: superannotate.SAClient.search_projects .. automethod:: superannotate.SAClient.create_project +.. automethod:: superannotate.SAClient.search_projects .. automethod:: superannotate.SAClient.create_project_from_metadata .. automethod:: superannotate.SAClient.clone_project -.. 
automethod:: superannotate.SAClient.delete_project .. automethod:: superannotate.SAClient.rename_project +.. automethod:: superannotate.SAClient.delete_project .. _ref_get_project_metadata: .. automethod:: superannotate.SAClient.get_project_by_id +.. automethod:: superannotate.SAClient.set_project_status .. automethod:: superannotate.SAClient.get_project_metadata .. automethod:: superannotate.SAClient.get_project_image_count -.. automethod:: superannotate.SAClient.search_folders -.. automethod:: superannotate.SAClient.assign_folder -.. automethod:: superannotate.SAClient.unassign_folder -.. automethod:: superannotate.SAClient.get_folder_by_id -.. automethod:: superannotate.SAClient.get_folder_metadata -.. automethod:: superannotate.SAClient.create_folder -.. automethod:: superannotate.SAClient.delete_folders .. automethod:: superannotate.SAClient.upload_images_to_project .. automethod:: superannotate.SAClient.attach_items_from_integrated_storage .. automethod:: superannotate.SAClient.upload_image_to_project @@ -30,5 +24,5 @@ Projects .. automethod:: superannotate.SAClient.add_contributors_to_project .. automethod:: superannotate.SAClient.get_project_settings .. automethod:: superannotate.SAClient.set_project_default_image_quality_in_editor -.. automethod:: superannotate.SAClient.get_project_workflow .. automethod:: superannotate.SAClient.set_project_workflow +.. automethod:: superannotate.SAClient.get_project_workflow diff --git a/docs/source/api_reference/api_team.rst b/docs/source/api_reference/api_team.rst index cbf335851..d78ea31e3 100644 --- a/docs/source/api_reference/api_team.rst +++ b/docs/source/api_reference/api_team.rst @@ -1,6 +1,6 @@ -========== +==== Team -========== +==== .. 
automethod:: superannotate.SAClient.get_team_metadata diff --git a/requirements.txt b/requirements.txt index a06d93f40..e7777454d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,6 @@ mixpanel==4.8.3 pydantic>=1.10.4 setuptools>=57.4.0 email-validator>=1.0.3 -nest-asyncio==1.5.4 jsonschema==3.2.0 pandas>=1.1.4 aiofiles==0.8.0 diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index 0838c956d..5c1b72fe8 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -3,7 +3,8 @@ import sys import typing -__version__ = "4.4.10" +__version__ = "4.4.11" + sys.path.append(os.path.split(os.path.realpath(__file__))[0]) diff --git a/src/superannotate/lib/app/interface/base_interface.py b/src/superannotate/lib/app/interface/base_interface.py index 5d4e141a5..930d84e03 100644 --- a/src/superannotate/lib/app/interface/base_interface.py +++ b/src/superannotate/lib/app/interface/base_interface.py @@ -1,6 +1,7 @@ import functools import json import os +import platform import sys import typing from inspect import signature @@ -33,7 +34,7 @@ def __init__(self, token: TokenStr = None, config_path: str = None): if token: config = ConfigEntity(SA_TOKEN=token) elif config_path: - config_path = Path(config_path) + config_path = Path(config_path).expanduser() if not Path(config_path).is_file() or not os.access( config_path, os.R_OK ): @@ -124,10 +125,10 @@ def _retrieve_configs_from_env() -> typing.Union[ConfigEntity, None]: class Tracker: def get_mp_instance(self) -> Mixpanel: client = self.get_client() - mp_token = "ca95ed96f80e8ec3be791e2d3097cf51" - if client: - if client.host != constants.BACKEND_URL: - mp_token = "e741d4863e7e05b1a45833d01865ef0d" + if client.controller._config.API_URL == constants.BACKEND_URL: # noqa + mp_token = "ca95ed96f80e8ec3be791e2d3097cf51" + else: + mp_token = "e741d4863e7e05b1a45833d01865ef0d" return Mixpanel(mp_token) @staticmethod @@ -137,6 +138,8 @@ def get_default_payload(team_name, user_id): 
"Team": team_name, "Team Owner": user_id, "Version": __version__, + "Python version": platform.python_version(), + "Python interpreter type": platform.python_implementation(), } def __init__(self, function): @@ -170,6 +173,10 @@ def default_parser(function_name: str, kwargs: dict) -> tuple: for key, value in kwargs.items(): if key == "self": continue + elif key == "token": + properties["sa_token"] = str(bool(value)) + elif key == "config_path": + properties[key] = str(bool(value)) elif value is None: properties[key] = value elif key == "project": @@ -241,5 +248,6 @@ def __new__(mcs, name, bases, attrs): attr_value, FunctionType ) and not attr_value.__name__.startswith("_"): attrs[attr_name] = Tracker(validate_arguments(attr_value)) + attrs["__init__"] = Tracker(validate_arguments(attrs["__init__"])) tmp = super().__new__(mcs, name, bases, attrs) return tmp diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index 6566d961a..b2a135187 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -2,6 +2,7 @@ import copy import io import json +import logging import os import sys import warnings @@ -63,7 +64,7 @@ from lib.core.types import Project from lib.infrastructure.utils import extract_project_folder from lib.infrastructure.validators import wrap_error -import logging + logger = logging.getLogger("sa") @@ -71,7 +72,7 @@ NotEmptyStr = TypeVar("NotEmptyStr", bound=constr(strict=True, min_length=1)) -PROJECT_STATUS = Literal["Undefined", "NotStarted", "InProgress", "Completed", "OnHold"] +PROJECT_STATUS = Literal["NotStarted", "InProgress", "Completed", "OnHold"] PROJECT_TYPE = Literal[ "Vector", "Pixel", "Video", "Document", "Tiled", "Other", "PointCloud" @@ -91,13 +92,7 @@ ANNOTATOR_ROLE = Literal["Admin", "Annotator", "QA"] -FOLDER_STATUS = Literal[ - "Undefined", - "NotStarted", - "InProgress", - "Completed", - "OnHold", -] 
+FOLDER_STATUS = Literal["NotStarted", "InProgress", "Completed", "OnHold"] class Setting(TypedDict): @@ -782,6 +777,52 @@ def search_annotation_classes( for i in response.data ] + def set_project_status(self, project: NotEmptyStr, status: PROJECT_STATUS): + """Set project status + + :param project: project name + :type project: str + :param status: status to set, should be one of. \n + ♦ “NotStarted” \n + ♦ “InProgress” \n + ♦ “Completed” \n + ♦ “OnHold” \n + :type status: str + """ + project = self.controller.get_project(name=project) + project.status = constants.ProjectStatus.get_value(status) + response = self.controller.projects.update(project) + if response.errors: + raise AppException(f"Failed to change {project.name} status.") + logger.info(f"Successfully updated {project.name} status to {status}") + + def set_folder_status( + self, project: NotEmptyStr, folder: NotEmptyStr, status: FOLDER_STATUS + ): + """Set folder status + + :param project: project name + :type project: str + :param folder: folder name + :type folder: str + :param status: status to set, should be one of. \n + ♦ “NotStarted” \n + ♦ “InProgress” \n + ♦ “Completed” \n + ♦ “OnHold” \n + :type status: str + """ + project, folder = self.controller.get_project_folder( + project_name=project, folder_name=folder + ) + folder.status = constants.FolderStatus.get_value(status) + response = self.controller.update(project, folder) + if response.errors: + raise AppException(f"Failed to change {project.name}/{folder.name} status.") + logger.info( + f"Successfully updated {project.name}/{folder.name} status to {status}" + ) + def set_project_default_image_quality_in_editor( self, project: Union[NotEmptyStr, dict], @@ -1360,7 +1401,8 @@ def create_annotation_class( :type color: str :param attribute_groups: list of attribute group dicts. - The values for the "group_type" key are "radio"|"checklist"|"text"|"numeric". + The values for the "group_type" key are "radio"|"checklist"|"text"|"numeric"|"ocr". 
+ "ocr "group_type" key is only available for Vector projects. Mandatory keys for each attribute group are - "name" @@ -2340,7 +2382,7 @@ def query( subset: Optional[NotEmptyStr] = None, ): """Return items that satisfy the given query. - Query syntax should be in SuperAnnotate query language(https://doc.superannotate.com/docs/query-search-1). + Query syntax should be in SuperAnnotate query language(https://doc.superannotate.com/docs/explore-overview). :param project: project name or folder path (e.g., “project1/folder1”) :type project: str diff --git a/src/superannotate/lib/core/__init__.py b/src/superannotate/lib/core/__init__.py index 7322412e4..3354cf264 100644 --- a/src/superannotate/lib/core/__init__.py +++ b/src/superannotate/lib/core/__init__.py @@ -16,7 +16,6 @@ from lib.core.enums import UploadState from lib.core.enums import UserRole - CONFIG = Config() BACKEND_URL = "https://api.superannotate.com" HOME_PATH = expanduser("~/.superannotate") @@ -33,33 +32,32 @@ def setup_logging(level=DEFAULT_LOGGING_LEVEL, file_path=LOG_FILE_LOCATION): - logger = logging.getLogger("sa") for handler in logger.handlers[:]: # remove all old handlers logger.removeHandler(handler) - logger.propagate = True + logger.propagate = False logger.setLevel(level) stream_handler = logging.StreamHandler() formatter = Formatter("SA-PYTHON-SDK - %(levelname)s - %(message)s") stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) try: + os.makedirs(file_path, exist_ok=True) log_file_path = os.path.join(file_path, "sa.log") - open(log_file_path, "w").close() - if os.access(log_file_path, os.W_OK): - file_handler = RotatingFileHandler( - log_file_path, - maxBytes=5 * 1024 * 1024, - backupCount=5, - mode="a", - ) - file_formatter = Formatter( - "SA-PYTHON-SDK - %(levelname)s - %(asctime)s - %(message)s" - ) - file_handler.setFormatter(file_formatter) - logger.addHandler(file_handler) + file_handler = RotatingFileHandler( + log_file_path, + maxBytes=5 * 1024 * 1024, + 
backupCount=5, + mode="a", + ) + file_formatter = Formatter( + "SA-PYTHON-SDK - %(levelname)s - %(asctime)s - %(message)s" + ) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + except OSError as e: - logging.error(e) + logger.debug(e) DEFAULT_IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "tif", "tiff", "webp", "bmp"] diff --git a/src/superannotate/lib/core/entities/project.py b/src/superannotate/lib/core/entities/project.py index 7a450490c..1ae6693b1 100644 --- a/src/superannotate/lib/core/entities/project.py +++ b/src/superannotate/lib/core/entities/project.py @@ -137,6 +137,9 @@ def __copy__(self): upload_state=self.upload_state, ) + def __eq__(self, other): + return self.id == other.id + class MLModelEntity(TimedBaseModel): id: Optional[int] diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 21bf0d258..53ac6f798 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -332,10 +332,9 @@ async def list_small_annotations( raise NotImplementedError @abstractmethod - def sort_items_by_size( + def get_upload_chunks( self, project: entities.ProjectEntity, - folder: entities.FolderEntity, item_ids: List[int], ) -> Dict[str, List]: raise NotImplementedError diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py index 9280a90f9..e8ee66274 100644 --- a/src/superannotate/lib/core/usecases/annotations.py +++ b/src/superannotate/lib/core/usecases/annotations.py @@ -9,6 +9,7 @@ import re import time import traceback +import typing from dataclasses import dataclass from datetime import datetime from itertools import islice @@ -26,7 +27,6 @@ import boto3 import jsonschema.validators import lib.core as constants -import nest_asyncio from jsonschema import Draft7Validator from jsonschema import ValidationError from lib.core.conditions import Condition @@ -58,6 +58,20 @@ 
URI_THRESHOLD = 4 * 1024 - 120 +def run_async(f): + from threading import Thread + + response = [None] + + def foo(f: typing.Callable, res): + res[0] = asyncio.run(f) # noqa + + thread = Thread(target=foo, args=(f, response)) + thread.start() + thread.join() + return response[0] + + @dataclass class Report: failed_annotations: list @@ -391,8 +405,7 @@ def execute(self): len(items_to_upload), description="Uploading Annotations" ) try: - nest_asyncio.apply() - asyncio.run(self.run_workers(items_to_upload)) + run_async(self.run_workers(items_to_upload)) except Exception: logger.debug(traceback.format_exc()) self._response.errors = AppException("Can't upload annotations.") @@ -737,8 +750,7 @@ def execute(self): except KeyError: missing_annotations.append(name) try: - nest_asyncio.apply() - asyncio.run(self.run_workers(items_to_upload)) + run_async(self.run_workers(items_to_upload)) except Exception as e: logger.debug(e) self._response.errors = AppException("Can't upload annotations.") @@ -935,9 +947,8 @@ def execute(self): json.dump(annotation_json, annotation_file) size = annotation_file.tell() annotation_file.seek(0) - nest_asyncio.apply() if size > BIG_FILE_THRESHOLD: - uploaded = asyncio.run( + uploaded = run_async( self._service_provider.annotations.upload_big_annotation( project=self._project, folder=self._folder, @@ -949,7 +960,7 @@ def execute(self): if not uploaded: self._response.errors = constants.INVALID_JSON_MESSAGE else: - response = asyncio.run( + response = run_async( self._service_provider.annotations.upload_small_annotations( project=self._project, folder=self._folder, @@ -1391,7 +1402,6 @@ def __init__( self._item_names = item_names self._item_names_provided = True self._big_annotations_queue = None - self._small_annotations_queue = None def validate_project_type(self): if self._project.type == constants.ProjectType.PIXEL.value: @@ -1440,29 +1450,18 @@ async def get_big_annotation(self): break return large_annotations - async def 
get_small_annotations(self): - small_annotations = [] - while True: - items = await self._small_annotations_queue.get() - if items: - annotations = ( - await self._service_provider.annotations.list_small_annotations( - project=self._project, - folder=self._folder, - item_ids=[i.id for i in items], - reporter=self.reporter, - ) - ) - small_annotations.extend(annotations) - else: - await self._small_annotations_queue.put(None) - break - return small_annotations + async def get_small_annotations(self, item_ids: List[int]): + return await self._service_provider.annotations.list_small_annotations( + project=self._project, + folder=self._folder, + item_ids=item_ids, + reporter=self.reporter, + ) async def run_workers( self, big_annotations: List[BaseItemEntity], - small_annotations: List[BaseItemEntity], + small_annotations: List[List[Dict]], ): annotations = [] if big_annotations: @@ -1485,26 +1484,16 @@ async def run_workers( ) ) if small_annotations: - self._small_annotations_queue = asyncio.Queue() - small_chunks = divide_to_chunks( - small_annotations, size=self._config.ANNOTATION_CHUNK_SIZE - ) - for chunk in small_chunks: - self._small_annotations_queue.put_nowait(chunk) - self._small_annotations_queue.put_nowait(None) - - annotations.extend( - list( - itertools.chain.from_iterable( - await asyncio.gather( - *[ - self.get_small_annotations() - for _ in range(self._config.MAX_COROUTINE_COUNT) - ] - ) - ) + for chunks in divide_to_chunks( + small_annotations, self._config.MAX_COROUTINE_COUNT + ): + tasks = [] + for chunk in chunks: + tasks.append(self.get_small_annotations([i["id"] for i in chunk])) + annotations.extend( + list(itertools.chain.from_iterable(await asyncio.gather(*tasks))) ) - ) + return list(filter(None, annotations)) def execute(self): @@ -1527,7 +1516,6 @@ def execute(self): items = get_or_raise(self._service_provider.items.list(condition)) else: items = [] - id_item_map = {i.id: i for i in items} if not items: logger.info("No annotations to 
download.") self._response.data = [] @@ -1537,21 +1525,22 @@ def execute(self): f"Getting {items_count} annotations from " f"{self._project.name}{f'/{self._folder.name}' if self._folder.name != 'root' else ''}." ) + id_item_map = {i.id: i for i in items} self.reporter.start_progress( items_count, disable=logger.level > logging.INFO or self.reporter.log_enabled, ) - - sort_response = self._service_provider.annotations.sort_items_by_size( - project=self._project, folder=self._folder, item_ids=list(id_item_map) + sort_response = self._service_provider.annotations.get_upload_chunks( + project=self._project, + item_ids=list(id_item_map), ) large_item_ids = set(map(itemgetter("id"), sort_response["large"])) - small_items_ids = set(map(itemgetter("id"), sort_response["small"])) - large_items = list(filter(lambda item: item.id in large_item_ids, items)) - small_items = list(filter(lambda item: item.id in small_items_ids, items)) + large_items: List[BaseItemEntity] = list( + filter(lambda item: item.id in large_item_ids, items) + ) + small_items: List[List[dict]] = sort_response["small"] try: - nest_asyncio.apply() - annotations = asyncio.run(self.run_workers(large_items, small_items)) + annotations = run_async(self.run_workers(large_items, small_items)) except Exception as e: logger.error(e) self._response.errors = AppException("Can't get annotations.") @@ -1585,7 +1574,6 @@ def __init__( self._service_provider = service_provider self._callback = callback self._big_file_queue = None - self._small_file_queue = None def validate_item_names(self): if self._item_names: @@ -1664,28 +1652,24 @@ async def download_big_annotations(self, export_path): self._big_file_queue.put_nowait(None) break - async def download_small_annotations(self, export_path, folder: FolderEntity): + async def download_small_annotations( + self, item_ids: List[int], export_path, folder: FolderEntity + ): postfix = self.get_postfix() - while True: - items = await self._small_file_queue.get() - if items: - 
await self._service_provider.annotations.download_small_annotations( - project=self._project, - folder=folder, - item_ids=[i.id for i in items], - reporter=self.reporter, - download_path=f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}", - postfix=postfix, - callback=self._callback, - ) - else: - self._small_file_queue.put_nowait(None) - break + await self._service_provider.annotations.download_small_annotations( + project=self._project, + folder=folder, + item_ids=item_ids, + reporter=self.reporter, + download_path=f"{export_path}{'/' + self._folder.name if not self._folder.is_root else ''}", + postfix=postfix, + callback=self._callback, + ) async def run_workers( self, big_annotations: List[BaseItemEntity], - small_annotations: List[BaseItemEntity], + small_annotations: List[List[dict]], folder: FolderEntity, export_path, ): @@ -1702,19 +1686,17 @@ async def run_workers( ) if small_annotations: - self._small_file_queue = asyncio.Queue() - small_chunks = divide_to_chunks( - small_annotations, size=self._config.ANNOTATION_CHUNK_SIZE - ) - for chunk in small_chunks: - self._small_file_queue.put_nowait(chunk) - self._small_file_queue.put_nowait(None) - await asyncio.gather( - *[ - self.download_small_annotations(export_path, folder) - for _ in range(self._config.MAX_COROUTINE_COUNT) - ] - ) + for chunks in divide_to_chunks( + small_annotations, self._config.MAX_COROUTINE_COUNT + ): + tasks = [] + for chunk in chunks: + tasks.append( + self.download_small_annotations( + [i["id"] for i in chunk], export_path, folder + ) + ) + await asyncio.gather(*tasks) def execute(self): if self.is_valid(): @@ -1735,7 +1717,6 @@ def execute(self): ).data if not folders: folders.append(self._folder) - nest_asyncio.apply() for folder in folders: if self._item_names: items = get_or_raise( @@ -1755,21 +1736,17 @@ def execute(self): new_export_path += f"/{folder.name}" id_item_map = {i.id: i for i in items} - sort_response = 
self._service_provider.annotations.sort_items_by_size( + sort_response = self._service_provider.annotations.get_upload_chunks( project=self._project, - folder=self._folder, item_ids=list(id_item_map), ) large_item_ids = set(map(itemgetter("id"), sort_response["large"])) - small_items_ids = set(map(itemgetter("id"), sort_response["small"])) - large_items = list( + large_items: List[BaseItemEntity] = list( filter(lambda item: item.id in large_item_ids, items) ) - small_items = list( - filter(lambda item: item.id in small_items_ids, items) - ) + small_items: List[List[dict]] = sort_response["small"] try: - asyncio.run( + run_async( self.run_workers( large_items, small_items, folder, new_export_path ) diff --git a/src/superannotate/lib/core/usecases/classes.py b/src/superannotate/lib/core/usecases/classes.py index 9ab99bee6..832852a3d 100644 --- a/src/superannotate/lib/core/usecases/classes.py +++ b/src/superannotate/lib/core/usecases/classes.py @@ -6,6 +6,7 @@ from lib.core.conditions import CONDITION_EQ as EQ from lib.core.entities import AnnotationClassEntity from lib.core.entities import ProjectEntity +from lib.core.entities.classes import GroupTypeEnum from lib.core.enums import ProjectType from lib.core.exceptions import AppException from lib.core.serviceproviders import BaseServiceProvider @@ -66,6 +67,13 @@ def validate_project_type(self): "Predefined tagging functionality is not supported for projects" f" of type {ProjectType.get_name(self._project.type)}." ) + if self._project.type != ProjectType.VECTOR: + for g in self._annotation_class.attribute_groups: + if g.group_type == GroupTypeEnum.OCR: + raise AppException( + f"OCR attribute group is not supported for project type " + f"{ProjectType.get_name(self._project.type)}." 
+ ) def validate_default_value(self): if self._project.type == ProjectType.PIXEL.value and any( @@ -109,13 +117,19 @@ def __init__( self._annotation_classes = annotation_classes def validate_project_type(self): - if self._project.type == ProjectType.PIXEL and any( - [True for i in self._annotation_classes if i.type == "tag"] - ): - raise AppException( - f"Predefined tagging functionality is not supported" - f" for projects of type {ProjectType.get_name(self._project.type)}." - ) + if self._project.type != ProjectType.VECTOR: + for c in self._annotation_classes: + if self._project.type == ProjectType.PIXEL and c.type == "tag": + raise AppException( + f"Predefined tagging functionality is not supported" + f" for projects of type {ProjectType.get_name(self._project.type)}." + ) + for g in c.attribute_groups: + if g.group_type == GroupTypeEnum.OCR: + raise AppException( + f"OCR attribute group is not supported for project type " + f"{ProjectType.get_name(self._project.type)}." + ) def validate_default_value(self): if self._project.type == ProjectType.PIXEL.value: diff --git a/src/superannotate/lib/core/usecases/folders.py b/src/superannotate/lib/core/usecases/folders.py index 2a07a323b..ac256fabf 100644 --- a/src/superannotate/lib/core/usecases/folders.py +++ b/src/superannotate/lib/core/usecases/folders.py @@ -197,7 +197,7 @@ def execute(self): self._project, self._folder ) if not response.ok: - self._response.errors = AppException("Couldn't rename folder.") + self._response.errors = AppException(response.error) self._response.data = response.data return self._response diff --git a/src/superannotate/lib/core/usecases/projects.py b/src/superannotate/lib/core/usecases/projects.py index 4fc1991b9..e671e32cd 100644 --- a/src/superannotate/lib/core/usecases/projects.py +++ b/src/superannotate/lib/core/usecases/projects.py @@ -359,7 +359,7 @@ def validate_project_name(self): response = self._service_provider.projects.list(condition) if response.ok: for project in 
response.data: - if project.name == self._project.name: + if project.name == self._project.name and project != self._project: logger.error("There are duplicated names.") raise AppValidationException( f"Project name {self._project.name} is not unique. " diff --git a/src/superannotate/lib/infrastructure/services/annotation.py b/src/superannotate/lib/infrastructure/services/annotation.py index c984d20d8..a1cb39473 100644 --- a/src/superannotate/lib/infrastructure/services/annotation.py +++ b/src/superannotate/lib/infrastructure/services/annotation.py @@ -24,7 +24,7 @@ class AnnotationService(BaseAnnotationService): - ASSETS_PROVIDER_VERSION = "v2" + ASSETS_PROVIDER_VERSION = "v2.01" DEFAULT_CHUNK_SIZE = 5000 URL_GET_ANNOTATIONS = "items/annotations/download" @@ -153,33 +153,24 @@ async def list_small_annotations( params=query_params, ) - def sort_items_by_size( + def get_upload_chunks( self, project: entities.ProjectEntity, - folder: entities.FolderEntity, item_ids: List[int], ) -> Dict[str, List]: - chunk_size = 2000 - query_params = { - "project_id": project.id, - "folder_id": folder.id, - } - response_data = {"small": [], "large": []} - for i in range(0, len(item_ids), chunk_size): - body = { - "item_ids": item_ids[i : i + chunk_size], # noqa - } # noqa - response = self.client.request( - url=urljoin(self.assets_provider_url, self.URL_CLASSIFY_ITEM_SIZE), - method="POST", - params=query_params, - data=body, - ) - if not response.ok: - raise AppException(response.error) - response_data["small"].extend(response.data.get("small", [])) - response_data["large"].extend(response.data.get("large", [])) + response = self.client.request( + url=urljoin(self.assets_provider_url, self.URL_CLASSIFY_ITEM_SIZE), + method="POST", + params={"project_id": project.id, "limit": len(item_ids)}, + data={"item_ids": item_ids}, + ) + if not response.ok: + raise AppException(response.error) + response_data = {"small": [ + i["data"] for i in response.data.get("small", {}).values() + ]} + 
response_data["large"] = response.data.get("large", []) return response_data async def download_big_annotation( diff --git a/src/superannotate/lib/infrastructure/validators.py b/src/superannotate/lib/infrastructure/validators.py index 1b3f8d862..351ea0367 100644 --- a/src/superannotate/lib/infrastructure/validators.py +++ b/src/superannotate/lib/infrastructure/validators.py @@ -27,7 +27,7 @@ def make_literal_validator( def literal_validator(v: typing.Any) -> typing.Any: try: return allowed_choices[v.lower()] - except KeyError: + except (KeyError, AttributeError): raise WrongConstantError(given=v, permitted=permitted_choices) return literal_validator diff --git a/tests/integration/annotations/test_annotation_class_new.py b/tests/integration/annotations/test_annotation_class_new.py deleted file mode 100644 index 7bcea9400..000000000 --- a/tests/integration/annotations/test_annotation_class_new.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -from pathlib import Path - -from src.superannotate import SAClient -from tests.integration.base import BaseTestCase - -sa = SAClient() - - -class TestAnnotationClasses(BaseTestCase): - PROJECT_NAME = "test_annotation_class_new" - PROJECT_DESCRIPTION = "desc" - PROJECT_TYPE = "Vector" - - @property - def classes_json(self): - return os.path.join( - Path(__file__).parent.parent.parent, - "data_set/sample_project_vector/classes/classes.json", - ) - - def test_create_annotation_class(self): - sa.create_annotation_class(self.PROJECT_NAME, "tt", "#FFFFFF") - classes = sa.search_annotation_classes(self.PROJECT_NAME) - self.assertEqual(len(classes), 1) - self.assertEqual(classes[0]["type"], "object") - - def test_annotation_classes_filter(self): - sa.create_annotation_class(self.PROJECT_NAME, "tt", "#FFFFFF") - sa.create_annotation_class(self.PROJECT_NAME, "tb", "#FFFFFF") - classes = sa.search_annotation_classes(self.PROJECT_NAME, "bb") - self.assertEqual(len(classes), 0) - classes = sa.search_annotation_classes(self.PROJECT_NAME, "tt") - 
self.assertEqual(len(classes), 1) - - def test_create_annotation_class_from_json(self): - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, self.classes_json - ) - self.assertEqual(len(sa.search_annotation_classes(self.PROJECT_NAME)), 4) - - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, self.classes_json - ) - self.assertEqual(len(sa.search_annotation_classes(self.PROJECT_NAME)), 4) diff --git a/tests/integration/annotations/test_annotation_classes.py b/tests/integration/annotations/test_annotation_classes.py deleted file mode 100644 index 271427c23..000000000 --- a/tests/integration/annotations/test_annotation_classes.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -from pathlib import Path -from urllib.parse import urlparse - -from src.superannotate import SAClient -from tests.integration.base import BaseTestCase - -sa = SAClient() - - -class TestAnnotationClasses(BaseTestCase): - PROJECT_NAME_ = "TestAnnotationClasses" - PROJECT_DESCRIPTION = "desc" - PROJECT_TYPE = "Vector" - CLASSES_JON_PATH = "data_set/invalid_json/classes.json" - - @property - def classes_path(self): - return os.path.join(Path(__file__).parent.parent.parent, self.CLASSES_JON_PATH) - - def test_invalid_json(self): - try: - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, self.classes_path - ) - except Exception as e: - self.assertIn("Couldn't validate annotation classes", str(e)) - - def test_annotation_classes(self): - annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) - self.assertEqual(len(annotation_classes), 0) - sa.create_annotation_class(self.PROJECT_NAME, "fff", "#FFFFFF") - annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) - self.assertEqual(len(annotation_classes), 1) - - annotation_class = sa.search_annotation_classes(self.PROJECT_NAME, "ff")[0] - sa.delete_annotation_class(self.PROJECT_NAME, annotation_class) - annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) - 
self.assertEqual(len(annotation_classes), 0) - - def test_annotation_classes_from_s3(self): - annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) - self.assertEqual(len(annotation_classes), 0) - f = urlparse("s3://superannotate-python-sdk-test/sample_project_pixel") - - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, - f.path[1:] + "/classes/classes.json", - from_s3_bucket=f.netloc, - ) - annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) - self.assertEqual(len(annotation_classes), 5) diff --git a/tests/integration/annotations/test_get_annotations.py b/tests/integration/annotations/test_get_annotations.py index 1267c90f3..7c4d2b34d 100644 --- a/tests/integration/annotations/test_get_annotations.py +++ b/tests/integration/annotations/test_get_annotations.py @@ -119,7 +119,7 @@ def test_get_annotations10000(self): [ {"name": f"example_image_{i}.jpg", "url": f"url_{i}"} for i in range(count) - ], # noqa + ], ) assert len(sa.search_items(self.PROJECT_NAME)) == count a = sa.get_annotations(self.PROJECT_NAME) diff --git a/tests/integration/classes/test_create_annotation_class.py b/tests/integration/classes/test_create_annotation_class.py index 6463ce5af..52e344be8 100644 --- a/tests/integration/classes/test_create_annotation_class.py +++ b/tests/integration/classes/test_create_annotation_class.py @@ -4,31 +4,39 @@ import pytest from src.superannotate import AppException from src.superannotate import SAClient +from src.superannotate.lib.core.entities.classes import AnnotationClassEntity from tests import DATA_SET_PATH from tests.integration.base import BaseTestCase sa = SAClient() -class TestCreateAnnotationClass(BaseTestCase): - PROJECT_NAME = "TestCreateAnnotationClass" +class TestVectorAnnotationClasses(BaseTestCase): + PROJECT_NAME = "TestVectorAnnotationClasses" + PROJECT_DESCRIPTION = "desc" PROJECT_TYPE = "Vector" - PROJECT_DESCRIPTION = "Example " - TEST_LARGE_CLASSES_JSON = "large_classes_json.json" - 
EXAMPLE_IMAGE_1 = "example_image_1.jpg" - @property - def large_json_path(self): - return os.path.join(DATA_SET_PATH, self.TEST_LARGE_CLASSES_JSON) + def test_create_annotation_class_search(self): + sa.create_annotation_class(self.PROJECT_NAME, "tt", "#FFFFFF") + classes = sa.search_annotation_classes(self.PROJECT_NAME) + self.assertEqual(len(classes), 1) + self.assertEqual(classes[0]["type"], "object") + self.assertEqual(classes[0]["color"], "#FFFFFF") + sa.create_annotation_class(self.PROJECT_NAME, "tb", "#FFFFFF") + # test search + classes = sa.search_annotation_classes(self.PROJECT_NAME, "bb") + self.assertEqual(len(classes), 0) + classes = sa.search_annotation_classes(self.PROJECT_NAME, "tt") + self.assertEqual(len(classes), 1) - def test_create_annotation_class(self): + def test_create_tag_annotation_class(self): sa.create_annotation_class( self.PROJECT_NAME, "test_add", "#FF0000", class_type="tag" ) classes = sa.search_annotation_classes(self.PROJECT_NAME) self.assertEqual(classes[0]["type"], "tag") - def test_create_annotation_class_with_attr(self): + def test_create_annotation_class_with_attr_and_default_value(self): _class = sa.create_annotation_class( self.PROJECT_NAME, "test_add", @@ -37,26 +45,17 @@ def test_create_annotation_class_with_attr(self): { "name": "test", "attributes": [{"name": "Car"}, {"name": "Track"}, {"name": "Bus"}], + "default_value": "Bus", } ], ) assert "is_multiselect" not in _class["attribute_groups"][0] classes = sa.search_annotation_classes(self.PROJECT_NAME) assert "is_multiselect" not in classes[0]["attribute_groups"][0] + assert classes[0]["attribute_groups"][0]["default_value"] == "Bus" - def test_create_annotations_classes_from_class_json(self): - classes = sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, self.large_json_path - ) - self.assertEqual(len(classes), 1500) - assert "is_multiselect" not in str(classes) - - def test_hex_color_adding(self): - sa.create_annotation_class(self.PROJECT_NAME, 
"test_add", color="#0000FF") - classes = sa.search_annotation_classes(self.PROJECT_NAME, "test_add") - assert classes[0]["color"] == "#0000FF" - - def test_create_annotation_class_with_default_attribute(self): + @pytest.mark.flaky(reruns=2) + def test_multi_select_to_checklist(self): sa.create_annotation_class( self.PROJECT_NAME, "test_add", @@ -65,43 +64,48 @@ def test_create_annotation_class_with_default_attribute(self): attribute_groups=[ { "name": "test", + "is_multiselect": 1, "attributes": [{"name": "Car"}, {"name": "Track"}, {"name": "Bus"}], - "default_value": "Bus", } ], ) classes = sa.search_annotation_classes(self.PROJECT_NAME) - assert classes[0]["attribute_groups"][0]["default_value"] == "Bus" + assert classes[0]["attribute_groups"][0]["group_type"] == "checklist" + assert classes[0]["attribute_groups"][0]["default_value"] == [] - def test_create_annotation_classes_with_default_attribute(self): - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, - classes_json=[ - { - "name": "Personal vehicle", - "color": "#ecb65f", - "count": 25, - "createdAt": "2020-10-12T11:35:20.000Z", - "updatedAt": "2020-10-12T11:48:19.000Z", - "attribute_groups": [ - { - "name": "test", - "attributes": [ - {"name": "Car"}, - {"name": "Track"}, - {"name": "Bus"}, - ], - "default_value": "Bus", - } - ], - } - ], + @pytest.mark.skip(reason="Need to adjust") + def test_create_annotation_class_video_error(self): + msg = "" + try: + sa.create_annotation_class( + self.PROJECT_NAME, "test_add", "#FF0000", class_type="tag" + ) + except Exception as e: + msg = str(e) + self.assertEqual( + msg, + "Predefined tagging functionality is not supported for projects of type Video.", ) - classes = sa.search_annotation_classes(self.PROJECT_NAME) - assert classes[0]["attribute_groups"][0]["default_value"] == "Bus" + + def test_create_radio_annotation_class_attr_required(self): + msg = "" + try: + sa.create_annotation_class( + self.PROJECT_NAME, + "test_add", + "#FF0000", + 
attribute_groups=[ + { + "group_type": "radio", + "name": "name", + } + ], + ) + except Exception as e: + msg = str(e) + self.assertEqual(msg, '"classes[0].attribute_groups[0].attributes" is required') def test_create_annotation_class_backend_errors(self): - from lib.core.entities.classes import AnnotationClassEntity response = sa.controller.annotation_classes.create( sa.controller.projects.get_by_name(self.PROJECT_NAME).data, @@ -156,47 +160,141 @@ def test_create_annotation_classes_with_empty_default_attribute(self): assert classes[0]["attribute_groups"][0]["default_value"] is None assert "is_multiselect" not in classes[0]["attribute_groups"][0] + def test_class_creation_type(self): + with tempfile.TemporaryDirectory() as tmpdir_name: + temp_path = f"{tmpdir_name}/new_classes.json" + with open(temp_path, "w") as new_classes: + new_classes.write( + """ + [ + { + "id":56820, + "project_id":7617, + "name":"Personal vehicle", + "color":"#547497", + "count":18, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "type": "tag", + "attribute_groups":[ + { + "id":21448, + "class_id":56820, + "name":"Large", + "is_multiselect":0, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:39:39.000Z", + "attributes":[ + { + "id":57096, + "group_id":21448, + "project_id":7617, + "name":"no", + "count":0, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:39:39.000Z" + }, + { + "id":57097, + "group_id":21448, + "project_id":7617, + "name":"yes", + "count":1, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z" + } + ] + } + ] + }, + { + "id":56821, + "project_id":7617, + "name":"Large vehicle", + "color":"#2ba36d", + "count":1, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "attribute_groups":[ + { + "id":21449, + "class_id":56821, + "name":"small", + "is_multiselect":0, + "createdAt":"2020-09-29T10:39:39.000Z", + 
"updatedAt":"2020-09-29T10:39:39.000Z", + "attributes":[ + { + "id":57098, + "group_id":21449, + "project_id":7617, + "name":"yes", + "count":0, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:39:39.000Z" + }, + { + "id":57099, + "group_id":21449, + "project_id":7617, + "name":"no", + "count":1, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z" + } + ] + } + ] + }, + { + "id":56822, + "project_id":7617, + "name":"Pedestrian", + "color":"#d4da03", + "count":3, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "attribute_groups":[ + + ] + }, + { + "id":56823, + "project_id":7617, + "name":"Two wheeled vehicle", + "color":"#f11aec", + "count":1, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "attribute_groups":[ -class TestCreateAnnotationClassNonVectorWithError(BaseTestCase): - PROJECT_NAME = "TestCreateAnnotationClassNonVectorWithError" - PROJECT_TYPE = "Video" - PROJECT_DESCRIPTION = "Example Project test pixel basic images" + ] + }, + { + "id":56824, + "project_id":7617, + "name":"Traffic sign", + "color":"#d8a7fd", + "count":9, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "attribute_groups":[ - @pytest.mark.skip(reason="Need to adjust") - def test_create_annotation_class(self): - msg = "" - try: - sa.create_annotation_class( - self.PROJECT_NAME, "test_add", "#FF0000", class_type="tag" - ) - except Exception as e: - msg = str(e) - self.assertEqual( - msg, - "Predefined tagging functionality is not supported for projects of type Video.", - ) + ] + } + ] - def test_create_radio_annotation_class_attr_required(self): - msg = "" - try: - sa.create_annotation_class( - self.PROJECT_NAME, - "test_add", - "#FF0000", - attribute_groups=[ - { - "group_type": "radio", - "name": "name", - } - ], + """ + ) + + created = sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, temp_path ) - except 
Exception as e: - msg = str(e) - self.assertEqual(msg, '"classes[0].attribute_groups[0].attributes" is required') + self.assertEqual({i["type"] for i in created}, {"tag", "object"}) -class TestCreateAnnotationClassesNonVectorWithError(BaseTestCase): - PROJECT_NAME = "TestCreateAnnotationClassesNonVectorWithError" +class TestVideoCreateAnnotationClasses(BaseTestCase): + PROJECT_NAME = "TestVideoCreateAnnotationClasses" PROJECT_TYPE = "Video" PROJECT_DESCRIPTION = "Example Project test pixel basic images" @@ -245,8 +343,33 @@ def test_create_annotation_class(self): "Predefined tagging functionality is not supported for projects of type Video.", ) + def test_create_annotation_class_via_ocr_group_type(self): + with self.assertRaisesRegexp( + AppException, + f"OCR attribute group is not supported for project type {self.PROJECT_TYPE}.", + ): + attribute_groups = [ + { + "id": 21448, + "class_id": 56820, + "name": "Large", + "group_type": "ocr", + "is_multiselect": 0, + "createdAt": "2020-09-29T10:39:39.000Z", + "updatedAt": "2020-09-29T10:39:39.000Z", + "attributes": [], + } + ] + sa.create_annotation_class( + self.PROJECT_NAME, + "test_add", + "#FF0000", + attribute_groups, # noqa + class_type="tag", + ) + -class TestCreateAnnotationClassPixel(BaseTestCase): +class TestPixelCreateAnnotationClass(BaseTestCase): PROJECT_NAME = "TestCreateAnnotationClassPixel" PROJECT_TYPE = "Pixel" PROJECT_DESCRIPTION = "Example " @@ -277,32 +400,3 @@ def test_create_annotation_class_with_default_attribute(self): } ], ) - - def test_create_annotation_classes_with_default_attribute(self): - with self.assertRaisesRegexp( - AppException, - 'The "default_value" key is not supported for project type Pixel.', - ): - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, - classes_json=[ - { - "name": "Personal vehicle", - "color": "#ecb65f", - "count": 25, - "createdAt": "2020-10-12T11:35:20.000Z", - "updatedAt": "2020-10-12T11:48:19.000Z", - "attribute_groups": [ - { - "name": 
"test", - "attributes": [ - {"name": "Car"}, - {"name": "Track"}, - {"name": "Bus"}, - ], - "default_value": "Bus", - } - ], - } - ], - ) diff --git a/tests/integration/classes/test_create_annotation_classes_from_classes_json.py b/tests/integration/classes/test_create_annotation_classes_from_classes_json.py new file mode 100644 index 000000000..053c18503 --- /dev/null +++ b/tests/integration/classes/test_create_annotation_classes_from_classes_json.py @@ -0,0 +1,208 @@ +import os +import tempfile +from pathlib import Path +from urllib.parse import urlparse + +import pytest +from src.superannotate import AppException +from src.superannotate import SAClient +from tests import DATA_SET_PATH +from tests.integration.base import BaseTestCase + +sa = SAClient() + + +class TestVectorCreateAnnotationClass(BaseTestCase): + PROJECT_NAME = "TestCreateAnnotationClass" + PROJECT_TYPE = "Vector" + PROJECT_DESCRIPTION = "Example " + TEST_LARGE_CLASSES_JSON = "large_classes_json.json" + EXAMPLE_IMAGE_1 = "example_image_1.jpg" + INVALID_CLASSES_JON_PATH = "data_set/invalid_json/classes.json" + + @property + def large_json_path(self): + return os.path.join(DATA_SET_PATH, self.TEST_LARGE_CLASSES_JSON) + + @property + def invalid_classes_path(self): + return os.path.join( + Path(__file__).parent.parent.parent, self.INVALID_CLASSES_JON_PATH + ) + + @property + def classes_json(self): + return os.path.join( + Path(__file__).parent.parent.parent, + "data_set/sample_project_vector/classes/classes.json", + ) + + def test_create_annotation_class_from_json(self): + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, self.classes_json + ) + self.assertEqual(len(sa.search_annotation_classes(self.PROJECT_NAME)), 4) + + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, self.classes_json + ) + self.assertEqual(len(sa.search_annotation_classes(self.PROJECT_NAME)), 4) + + def test_invalid_json(self): + try: + sa.create_annotation_classes_from_classes_json( + 
self.PROJECT_NAME, self.invalid_classes_path + ) + except Exception as e: + self.assertIn("Couldn't validate annotation classes", str(e)) + + def test_create_annotations_classes_is_multiselect(self): + classes = sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, self.large_json_path + ) + self.assertEqual(len(classes), 1500) + assert "is_multiselect" not in str(classes) + + def test_create_annotation_classes_from_s3(self): + annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) + self.assertEqual(len(annotation_classes), 0) + f = urlparse("s3://superannotate-python-sdk-test/sample_project_pixel") + + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, + f.path[1:] + "/classes/classes.json", + from_s3_bucket=f.netloc, + ) + annotation_classes = sa.search_annotation_classes(self.PROJECT_NAME) + self.assertEqual(len(annotation_classes), 5) + + +class TestVideoCreateAnnotationClasses(BaseTestCase): + PROJECT_NAME = "TestVideoCreateAnnotationClasses" + PROJECT_TYPE = "Video" + PROJECT_DESCRIPTION = "Example Project test pixel basic images" + + @pytest.mark.skip(reason="Need to adjust") + def test_create_annotation_class(self): + with tempfile.TemporaryDirectory() as tmpdir_name: + temp_path = f"{tmpdir_name}/new_classes.json" + with open(temp_path, "w") as new_classes: + new_classes.write( + """ + [ + { + "id":56820, + "project_id":7617, + "name":"Personal vehicle", + "color":"#547497", + "count":18, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "type": "tag", + "attribute_groups":[ + { + "id":21448, + "class_id":56820, + "name":"Large", + "is_multiselect":0, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:39:39.000Z", + "attributes":[] + } + ] + } + ] + + """ + ) + msg = "" + try: + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, temp_path + ) + except Exception as e: + msg = str(e) + self.assertEqual( + msg, + "Predefined tagging 
functionality is not supported for projects of type Video.", + ) + + def test_create_annotation_class_via_json_and_ocr_group_type(self): + with tempfile.TemporaryDirectory() as tmpdir_name: + temp_path = f"{tmpdir_name}/new_classes.json" + with open(temp_path, "w") as new_classes: + new_classes.write( + """ + [ + { + "id":56820, + "project_id":7617, + "name":"Personal vehicle", + "color":"#547497", + "count":18, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:48:18.000Z", + "type": "tag", + "attribute_groups":[ + { + "id":21448, + "class_id":56820, + "name":"Large", + "group_type": "ocr", + "is_multiselect":0, + "createdAt":"2020-09-29T10:39:39.000Z", + "updatedAt":"2020-09-29T10:39:39.000Z", + "attributes":[] + } + ] + } + ] + """ + ) + with self.assertRaisesRegexp( + AppException, + f"OCR attribute group is not supported for project type {self.PROJECT_TYPE}.", + ): + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, temp_path + ) + + +class TestPixelCreateAnnotationClass(BaseTestCase): + PROJECT_NAME = "TestCreateAnnotationClassPixel" + PROJECT_TYPE = "Pixel" + PROJECT_DESCRIPTION = "Example " + TEST_LARGE_CLASSES_JSON = "large_classes_json.json" + + @property + def large_json_path(self): + return os.path.join(DATA_SET_PATH, self.TEST_LARGE_CLASSES_JSON) + + def test_create_annotation_classes_with_default_attribute(self): + with self.assertRaisesRegexp( + AppException, + 'The "default_value" key is not supported for project type Pixel.', + ): + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, + classes_json=[ + { + "name": "Personal vehicle", + "color": "#ecb65f", + "count": 25, + "createdAt": "2020-10-12T11:35:20.000Z", + "updatedAt": "2020-10-12T11:48:19.000Z", + "attribute_groups": [ + { + "name": "test", + "attributes": [ + {"name": "Car"}, + {"name": "Track"}, + {"name": "Bus"}, + ], + "default_value": "Bus", + } + ], + } + ], + ) diff --git a/tests/integration/classes/test_create_bed_handling.py 
b/tests/integration/classes/test_create_bed_handling.py deleted file mode 100644 index 926eb872b..000000000 --- a/tests/integration/classes/test_create_bed_handling.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -from src.superannotate import SAClient -from tests.integration.base import BaseTestCase - -sa = SAClient() - - -class TestCreateAnnotationClass(BaseTestCase): - PROJECT_NAME = "TestCreateAnnotationClassBED" - PROJECT_TYPE = "Vector" - PROJECT_DESCRIPTION = "Example " - TEST_LARGE_CLASSES_JSON = "large_classes_json.json" - EXAMPLE_IMAGE_1 = "example_image_1.jpg" - - @pytest.mark.flaky(reruns=2) - def test_multi_select_to_checklist(self): - sa.create_annotation_class( - self.PROJECT_NAME, - "test_add", - "#FF0000", - class_type="tag", - attribute_groups=[ - { - "name": "test", - "is_multiselect": 1, - "attributes": [{"name": "Car"}, {"name": "Track"}, {"name": "Bus"}], - } - ], - ) - classes = sa.search_annotation_classes(self.PROJECT_NAME) - assert classes[0]["attribute_groups"][0]["group_type"] == "checklist" - assert classes[0]["attribute_groups"][0]["default_value"] == [] diff --git a/tests/integration/classes/test_tag_annotation_classes.py b/tests/integration/classes/test_tag_annotation_classes.py deleted file mode 100644 index 967383f48..000000000 --- a/tests/integration/classes/test_tag_annotation_classes.py +++ /dev/null @@ -1,146 +0,0 @@ -import tempfile - -from src.superannotate import SAClient -from tests.integration.base import BaseTestCase - -sa = SAClient() - - -class TestTagClasses(BaseTestCase): - PROJECT_NAME = "sample_project_pixel" - PROJECT_TYPE = "Vector" - PROJECT_DESCRIPTION = "Example Project test pixel basic images" - TEST_FOLDER_PTH = "data_set/sample_project_pixel" - EXAMPLE_IMAGE_1 = "example_image_1.jpg" - - def test_class_creation_type(self): - with tempfile.TemporaryDirectory() as tmpdir_name: - temp_path = f"{tmpdir_name}/new_classes.json" - with open(temp_path, "w") as new_classes: - new_classes.write( - """ - [ - { - 
"id":56820, - "project_id":7617, - "name":"Personal vehicle", - "color":"#547497", - "count":18, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z", - "type": "tag", - "attribute_groups":[ - { - "id":21448, - "class_id":56820, - "name":"Large", - "is_multiselect":0, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:39:39.000Z", - "attributes":[ - { - "id":57096, - "group_id":21448, - "project_id":7617, - "name":"no", - "count":0, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:39:39.000Z" - }, - { - "id":57097, - "group_id":21448, - "project_id":7617, - "name":"yes", - "count":1, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z" - } - ] - } - ] - }, - { - "id":56821, - "project_id":7617, - "name":"Large vehicle", - "color":"#2ba36d", - "count":1, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z", - "attribute_groups":[ - { - "id":21449, - "class_id":56821, - "name":"small", - "is_multiselect":0, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:39:39.000Z", - "attributes":[ - { - "id":57098, - "group_id":21449, - "project_id":7617, - "name":"yes", - "count":0, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:39:39.000Z" - }, - { - "id":57099, - "group_id":21449, - "project_id":7617, - "name":"no", - "count":1, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z" - } - ] - } - ] - }, - { - "id":56822, - "project_id":7617, - "name":"Pedestrian", - "color":"#d4da03", - "count":3, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z", - "attribute_groups":[ - - ] - }, - { - "id":56823, - "project_id":7617, - "name":"Two wheeled vehicle", - "color":"#f11aec", - "count":1, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z", - "attribute_groups":[ - - ] - }, - { - "id":56824, - "project_id":7617, - 
"name":"Traffic sign", - "color":"#d8a7fd", - "count":9, - "createdAt":"2020-09-29T10:39:39.000Z", - "updatedAt":"2020-09-29T10:48:18.000Z", - "attribute_groups":[ - - ] - } - ] - - """ - ) - - created = sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, temp_path - ) - self.assertEqual({i["type"] for i in created}, {"tag", "object"}) diff --git a/tests/integration/folders/test_folders.py b/tests/integration/folders/test_folders.py index 528adbace..609fbd58c 100644 --- a/tests/integration/folders/test_folders.py +++ b/tests/integration/folders/test_folders.py @@ -291,7 +291,7 @@ def test_search_folder(self): ) assert len(folders) == 0 folders = sa.search_folders( - self.PROJECT_NAME, status="Undefined", return_metadata=True + self.PROJECT_NAME, status="OnHold", return_metadata=True ) assert len(folders) == 0 folders = sa.search_folders( diff --git a/tests/integration/folders/test_set_folder_status.py b/tests/integration/folders/test_set_folder_status.py new file mode 100644 index 000000000..267983a0b --- /dev/null +++ b/tests/integration/folders/test_set_folder_status.py @@ -0,0 +1,89 @@ +from unittest import TestCase +from unittest.mock import patch + +from src.superannotate import AppException +from src.superannotate.lib.core.service_types import ServiceResponse +from superannotate import SAClient + + +sa = SAClient() + + +class TestSetFolderStatus(TestCase): + PROJECT_NAME = "test_set_folder_status" + FOLDER_NAME = "test_folder" + PROJECT_DESCRIPTION = "desc" + PROJECT_TYPE = "Vector" + FOLDER_STATUSES = ["NotStarted", "InProgress", "Completed", "OnHold"] + + @classmethod + def setUpClass(cls, *args, **kwargs): + cls.tearDownClass() + cls._project = sa.create_project( + cls.PROJECT_NAME, cls.PROJECT_DESCRIPTION, cls.PROJECT_TYPE + ) + sa.create_folder(cls.PROJECT_NAME, cls.FOLDER_NAME) + folder = sa.get_folder_metadata( + project=cls.PROJECT_NAME, folder_name=cls.FOLDER_NAME + ) + assert folder["status"] == "NotStarted" + + @classmethod + def 
tearDownClass(cls) -> None: + sa.delete_project(cls.PROJECT_NAME) + + def test_set_folder_status(self): + with self.assertLogs("sa", level="INFO") as cm: + for index, status in enumerate(self.FOLDER_STATUSES): + sa.set_folder_status( + project=self.PROJECT_NAME, folder=self.FOLDER_NAME, status=status + ) + folder = sa.get_folder_metadata( + project=self.PROJECT_NAME, folder_name=self.FOLDER_NAME + ) + assert ( + f"INFO:sa:Successfully updated {self.PROJECT_NAME}/{self.FOLDER_NAME} status to {status}" + == cm.output[index] + ) + self.assertEqual(status, folder["status"]) + self.assertEqual(len(cm.output), len(self.FOLDER_STATUSES)) + + @patch("lib.infrastructure.services.folder.FolderService.update") + def test_set_folder_status_fail(self, update_function): + update_function.return_value = ServiceResponse(_error="ERROR") + with self.assertRaisesRegexp( + AppException, + f"Failed to change {self.PROJECT_NAME}/{self.FOLDER_NAME} status.", + ): + sa.set_folder_status( + project=self.PROJECT_NAME, folder=self.FOLDER_NAME, status="Completed" + ) + + def test_set_folder_status_via_invalid_status(self): + with self.assertRaisesRegexp( + AppException, + "Available values are 'NotStarted', 'InProgress', 'Completed', 'OnHold'.", + ): + sa.set_folder_status( + project=self.PROJECT_NAME, + folder=self.FOLDER_NAME, + status="InvalidStatus", + ) + + def test_set_folder_status_via_invalid_project(self): + with self.assertRaisesRegexp( + AppException, + "Project not found.", + ): + sa.set_folder_status( + project="Invalid Name", folder=self.FOLDER_NAME, status="Completed" + ) + + def test_set_folder_status_via_invalid_folder(self): + with self.assertRaisesRegexp( + AppException, + "Folder not found.", + ): + sa.set_folder_status( + project=self.PROJECT_NAME, folder="Invalid Name", status="Completed" + ) diff --git a/tests/integration/items/test_search_items.py b/tests/integration/items/test_search_items.py index d79d9b8ff..4779569fa 100644 --- 
a/tests/integration/items/test_search_items.py +++ b/tests/integration/items/test_search_items.py @@ -78,20 +78,3 @@ def test_search_items_recursive(self): items = sa.search_items(self.PROJECT_NAME, recursive=True) assert len(items) == 8 - - def test_search_items_by_annotator_email(self): - test_email = "shab.prog@gmail.com" - sa.add_contributors_to_project( - self.PROJECT_NAME, ["shab.prog@gmail.com"], "Annotator" - ) - sa.upload_images_from_folder_to_project( - self.PROJECT_NAME, self.folder_path, annotation_status="InProgress" - ) - sa.assign_items( - self.PROJECT_NAME, [self.IMAGE1_NAME, self.IMAGE2_NAME], test_email - ) - - items = sa.search_items( - self.PROJECT_NAME, annotator_email=test_email, recursive=True - ) - assert len(items) == 2 diff --git a/tests/integration/mixpanel/test_mixpanel_decorator.py b/tests/integration/mixpanel/test_mixpanel_decorator.py index 6f92145e6..a1c5e1abd 100644 --- a/tests/integration/mixpanel/test_mixpanel_decorator.py +++ b/tests/integration/mixpanel/test_mixpanel_decorator.py @@ -1,5 +1,8 @@ import copy +import platform +import tempfile import threading +from configparser import ConfigParser from unittest import TestCase from unittest.mock import patch @@ -18,6 +21,8 @@ class TestMixpanel(TestCase): "Team Owner": TEAM_DATA["creator_id"], "Version": __version__, "Success": True, + "Python version": platform.python_version(), + "Python interpreter type": platform.python_implementation(), } PROJECT_NAME = "TEST_MIX" PROJECT_DESCRIPTION = "Desc" @@ -49,6 +54,57 @@ def _safe_delete_project(cls, project_name): def default_payload(self): return copy.copy(self.BLANK_PAYLOAD) + @patch("lib.app.interface.base_interface.Tracker._track") + def test_init(self, track_method): + SAClient() + result = list(track_method.call_args)[0] + payload = self.default_payload + payload.update({"sa_token": "False", "config_path": "False"}) + assert result[1] == "__init__" + assert payload == result[2] + + 
@patch("lib.app.interface.base_interface.Tracker._track") + @patch("lib.core.usecases.GetTeamUseCase") + def test_init_via_token(self, get_team_use_case, track_method): + SAClient(token="test=3232") + result = list(track_method.call_args)[0] + payload = self.default_payload + payload.update( + { + "sa_token": "True", + "config_path": "False", + "Team": get_team_use_case().execute().data.name, + "Team Owner": get_team_use_case().execute().data.creator_id, + } + ) + assert result[1] == "__init__" + assert payload == result[2] + + @patch("lib.app.interface.base_interface.Tracker._track") + @patch("lib.core.usecases.GetTeamUseCase") + def test_init_via_config_file(self, get_team_use_case, track_method): + with tempfile.TemporaryDirectory() as config_dir: + config_ini_path = f"{config_dir}/config.ini" + with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path): + with open(f"{config_dir}/config.ini", "w") as config_ini: + config_parser = ConfigParser() + config_parser.optionxform = str + config_parser["DEFAULT"] = {"SA_TOKEN": "test=3232"} + config_parser.write(config_ini) + SAClient(config_path=f"{config_dir}/config.ini") + result = list(track_method.call_args)[0] + payload = self.default_payload + payload.update( + { + "sa_token": "False", + "config_path": "True", + "Team": get_team_use_case().execute().data.name, + "Team Owner": get_team_use_case().execute().data.creator_id, + } + ) + assert result[1] == "__init__" + assert payload == result[2] + @patch("lib.app.interface.base_interface.Tracker._track") def test_get_team_metadata(self, track_method): team = self.CLIENT.get_team_metadata() @@ -57,7 +113,7 @@ def test_get_team_metadata(self, track_method): payload = self.default_payload assert result[0] == team_owner assert result[1] == "get_team_metadata" - assert payload == list(track_method.call_args)[0][2] + assert payload == result[2] @patch("lib.app.interface.base_interface.Tracker._track") def test_search_team_contributors(self, track_method): @@ -72,7 
+128,7 @@ def test_search_team_contributors(self, track_method): payload = self.default_payload payload.update(kwargs) assert result[1] == "search_team_contributors" - assert payload == list(track_method.call_args)[0][2] + assert payload == result[2] @patch("lib.app.interface.base_interface.Tracker._track") def test_search_projects(self, track_method): @@ -87,7 +143,7 @@ def test_search_projects(self, track_method): payload = self.default_payload payload.update(kwargs) assert result[1] == "search_projects" - assert payload == list(track_method.call_args)[0][2] + assert payload == result[2] @patch("lib.app.interface.base_interface.Tracker._track") def test_create_project(self, track_method): @@ -110,7 +166,7 @@ def test_create_project(self, track_method): payload.update(kwargs) payload["settings"] = list(kwargs["settings"].keys()) assert result[1] == "create_project" - assert payload == list(track_method.call_args)[0][2] + assert payload == result[2] @pytest.mark.skip("Need to adjust") @patch("lib.app.interface.base_interface.Tracker._track") diff --git a/tests/integration/projects/test_set_project_status.py b/tests/integration/projects/test_set_project_status.py new file mode 100644 index 000000000..358b89eda --- /dev/null +++ b/tests/integration/projects/test_set_project_status.py @@ -0,0 +1,64 @@ +from unittest import TestCase +from unittest.mock import patch + +from src.superannotate import AppException +from src.superannotate.lib.core.service_types import ServiceResponse +from superannotate import SAClient + + +sa = SAClient() + + +class TestSetProjectStatus(TestCase): + PROJECT_NAME = "test_set_project_status" + PROJECT_DESCRIPTION = "desc" + PROJECT_TYPE = "Vector" + PROJECT_STATUSES = ["NotStarted", "InProgress", "Completed", "OnHold"] + + @classmethod + def setUpClass(cls, *args, **kwargs): + cls.tearDownClass() + cls._project = sa.create_project( + cls.PROJECT_NAME, cls.PROJECT_DESCRIPTION, cls.PROJECT_TYPE + ) + project = 
sa.get_project_metadata(cls.PROJECT_NAME) + assert project["status"] == "NotStarted" + + @classmethod + def tearDownClass(cls) -> None: + sa.delete_project(cls.PROJECT_NAME) + + def test_set_project_status(self): + with self.assertLogs("sa", level="INFO") as cm: + for index, status in enumerate(self.PROJECT_STATUSES): + sa.set_project_status(project=self.PROJECT_NAME, status=status) + project = sa.get_project_metadata(self.PROJECT_NAME) + assert ( + f"INFO:sa:Successfully updated {self.PROJECT_NAME} status to {status}" + == cm.output[index] + ) + self.assertEqual(status, project["status"]) + self.assertEqual(len(cm.output), len(self.PROJECT_STATUSES)) + + @patch("lib.infrastructure.services.project.ProjectService.update") + def test_set_project_status_fail(self, update_function): + update_function.return_value = ServiceResponse(_error="ERROR") + with self.assertRaisesRegexp( + AppException, + f"Failed to change {self.PROJECT_NAME} status.", + ): + sa.set_project_status(project=self.PROJECT_NAME, status="Completed") + + def test_set_project_status_via_invalid_status(self): + with self.assertRaisesRegexp( + AppException, + "Available values are 'NotStarted', 'InProgress', 'Completed', 'OnHold'.", + ): + sa.set_project_status(project=self.PROJECT_NAME, status="InvalidStatus") + + def test_set_project_status_via_invalid_project(self): + with self.assertRaisesRegexp( + AppException, + "Project not found.", + ): + sa.set_project_status(project="Invalid name", status="Completed") diff --git a/tests/unit/test_async_functions.py b/tests/unit/test_async_functions.py new file mode 100644 index 000000000..7630a8a60 --- /dev/null +++ b/tests/unit/test_async_functions.py @@ -0,0 +1,107 @@ +import asyncio +import concurrent.futures +from unittest import TestCase + +from superannotate import SAClient + +sa = SAClient() + + +class DummyIterator: + def __init__(self, delay, to): + self.delay = delay + self.i = 0 + self.to = to + + def __aiter__(self): + return self + + async def 
__anext__(self): + i = self.i + if i >= self.to: + raise StopAsyncIteration + self.i += 1 + if i: + await asyncio.sleep(self.delay) + return i + + +class TestAsyncFunctions(TestCase): + PROJECT_NAME = "TestAsync" + PROJECT_DESCRIPTION = "Desc" + PROJECT_TYPE = "Vector" + ATTACH_PAYLOAD = [{"name": f"name_{i}", "url": "url"} for i in range(4)] + UPLOAD_PAYLOAD = [{"metadata": {"name": f"name_{i}"}} for i in range(4)] + + @classmethod + def setUpClass(cls): + cls.tearDownClass() + cls._project = sa.create_project( + cls.PROJECT_NAME, cls.PROJECT_DESCRIPTION, cls.PROJECT_TYPE + ) + sa.attach_items(cls.PROJECT_NAME, cls.ATTACH_PAYLOAD) + + @classmethod + def tearDownClass(cls): + sa.delete_project(cls.PROJECT_NAME) + + @staticmethod + async def nested(): + annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME) + assert len(annotations) == 4 + + def test_get_annotations_in_running_event_loop(self): + async def _test(): + annotations = sa.get_annotations(self.PROJECT_NAME) + assert len(annotations) == 4 + + asyncio.run(_test()) + + def test_create_task_get_annotations_in_running_event_loop(self): + async def _test(): + task1 = asyncio.create_task(self.nested()) + task2 = asyncio.create_task(self.nested()) + await task1 + await task2 + + asyncio.run(_test()) + + def test_gather_get_annotations_in_running_event_loop(self): + async def gather_test(): + await asyncio.gather(self.nested(), self.nested()) + + asyncio.run(gather_test()) + + def test_gather_async_for(self): + async def gather_test(): + async for _ in DummyIterator(delay=0.01, to=2): + annotations = sa.get_annotations(TestAsyncFunctions.PROJECT_NAME) + assert len(annotations) == 4 + + asyncio.run(gather_test()) + + def test_upload_annotations_in_running_event_loop(self): + async def _test(): + annotations = sa.upload_annotations( + self.PROJECT_NAME, annotations=self.UPLOAD_PAYLOAD + ) + assert len(annotations["succeeded"]) == 4 + + asyncio.run(_test()) + + def test_upload_in_threads(self): + def 
_test(): + annotations = sa.upload_annotations( + self.PROJECT_NAME, annotations=self.UPLOAD_PAYLOAD + ) + assert len(annotations["succeeded"]) == 4 + return True + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor: + futures = [] + for i in range(8): + futures.append(executor.submit(_test)) + results = [] + for f in concurrent.futures.as_completed(futures): + results.append(f.result()) + assert all(results) diff --git a/tests/integration/classes/test_classes_serialization.py b/tests/unit/test_classes_serialization.py similarity index 99% rename from tests/integration/classes/test_classes_serialization.py rename to tests/unit/test_classes_serialization.py index 143fc484a..9ca92e956 100644 --- a/tests/integration/classes/test_classes_serialization.py +++ b/tests/unit/test_classes_serialization.py @@ -78,6 +78,6 @@ def test_group_type_wrong_arg(self): "'radio',", "'checklist',", "'numeric',", - "'text'", + "'text',", "'ocr'", ] == wrap_error(e).split() diff --git a/tests/unit/test_init.py b/tests/unit/test_init.py index dcb353d70..d882ba293 100644 --- a/tests/unit/test_init.py +++ b/tests/unit/test_init.py @@ -2,6 +2,7 @@ import os import tempfile from configparser import ConfigParser +from pathlib import Path from unittest import TestCase from unittest.mock import patch @@ -30,7 +31,6 @@ def test_init_via_token(self, get_team_use_case): @patch("lib.core.usecases.GetTeamUseCase") def test_init_via_config_json(self, get_team_use_case): with tempfile.TemporaryDirectory() as config_dir: - constants.HOME_PATH = config_dir config_ini_path = f"{config_dir}/config.ini" config_json_path = f"{config_dir}/config.json" with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path), patch( @@ -48,7 +48,6 @@ def test_init_via_config_json(self, get_team_use_case): def test_init_via_config_json_invalid_json(self): with tempfile.TemporaryDirectory() as config_dir: - constants.HOME_PATH = config_dir config_ini_path = f"{config_dir}/config.ini" config_json_path 
= f"{config_dir}/config.json" with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path), patch( @@ -65,7 +64,6 @@ def test_init_via_config_json_invalid_json(self): @patch("lib.core.usecases.GetTeamUseCase") def test_init_via_config_ini(self, get_team_use_case): with tempfile.TemporaryDirectory() as config_dir: - constants.HOME_PATH = config_dir config_ini_path = f"{config_dir}/config.ini" config_json_path = f"{config_dir}/config.json" with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path), patch( @@ -88,6 +86,34 @@ def test_init_via_config_ini(self, get_team_use_case): self._token.split("=")[-1] ) + @patch("lib.core.usecases.GetTeamUseCase") + def test_init_via_config_relative_filepath(self, get_team_use_case): + with tempfile.TemporaryDirectory(dir=Path("~").expanduser()) as config_dir: + config_ini_path = f"{config_dir}/config.ini" + config_json_path = f"{config_dir}/config.json" + with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path), patch( + "lib.core.CONFIG_JSON_FILE_LOCATION", config_json_path + ): + with open(f"{config_dir}/config.ini", "w") as config_ini: + config_parser = ConfigParser() + config_parser.optionxform = str + config_parser["DEFAULT"] = { + "SA_TOKEN": self._token, + "LOGGING_LEVEL": "DEBUG", + } + config_parser.write(config_ini) + for kwargs in ( + {}, + {"config_path": f"~/{Path(config_dir).name}/config.ini"}, + ): + sa = SAClient(**kwargs) + assert sa.controller._config.API_TOKEN == self._token + assert sa.controller._config.LOGGING_LEVEL == "DEBUG" + assert sa.controller._config.API_URL == constants.BACKEND_URL + assert get_team_use_case.call_args_list[0].kwargs["team_id"] == int( + self._token.split("=")[-1] + ) + @patch("lib.core.usecases.GetTeamUseCase") @patch.dict(os.environ, {"SA_URL": "SOME_URL", "SA_TOKEN": "SOME_TOKEN=123"}) def test_init_env(self, get_team_use_case): @@ -103,7 +129,6 @@ def test_init_env_invalid_token(self): def test_init_via_config_ini_invalid_token(self): with 
tempfile.TemporaryDirectory() as config_dir: - constants.HOME_PATH = config_dir config_ini_path = f"{config_dir}/config.ini" config_json_path = f"{config_dir}/config.json" with patch("lib.core.CONFIG_INI_FILE_LOCATION", config_ini_path), patch(