From b760c2fee736dc8d8008bfaa29d12bd497683829 Mon Sep 17 00:00:00 2001 From: Narek Mkhitaryan Date: Tue, 16 Jul 2024 17:28:05 +0400 Subject: [PATCH] changed classes interface to sdk_core --- .../lib/app/interface/sdk_interface.py | 152 +++++++----- .../lib/core/usecases/__init__.py | 1 - .../lib/core/usecases/classes.py | 221 ------------------ src/superannotate/lib/core/usecases/models.py | 2 +- .../lib/infrastructure/controller.py | 90 ------- ...est_uopload_annotations_without_classes.py | 4 +- .../classes/test_create_annotation_class.py | 53 ++--- ...te_annotation_classes_from_classes_json.py | 98 ++++++-- .../classes/test_delete_annotation_class.py | 24 ++ .../custom_fields/test_custom_schema.py | 1 - 10 files changed, 204 insertions(+), 442 deletions(-) delete mode 100644 src/superannotate/lib/core/usecases/classes.py create mode 100644 tests/integration/classes/test_delete_annotation_class.py diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index 887039acd..ad8c4c4be 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -60,12 +60,10 @@ from lib.core.types import MLModel from lib.core.types import PriorityScoreEntity -from lib.core.pydantic_v1 import ValidationError from lib.core.pydantic_v1 import constr from lib.core.pydantic_v1 import conlist from lib.core.pydantic_v1 import parse_obj_as from lib.infrastructure.utils import extract_project_folder -from lib.infrastructure.validators import wrap_error from superannotate_core.core.entities import BaseItemEntity from superannotate_core.app import Project, Folder @@ -743,7 +741,9 @@ def get_project_workflow(self, project: Union[str, dict]): return workflow.data def search_annotation_classes( - self, project: Union[NotEmptyStr, dict], name_contains: Optional[str] = None + self, + project: Union[NotEmptyStr, dict, Project], + name_contains: Optional[str] = None, ): """Searches annotation classes by name_prefix (case-insensitive) @@ -757,17 +757,20 @@ def search_annotation_classes( :return: annotation classes of the project :rtype: list of dicts """ - project_name, folder_name = extract_project_folder(project) + if isinstance(project, Project): + project = project.dict() + + project_name, _ = extract_project_folder(project) project = self.controller.get_project(project_name) - condition = Condition("project_id", project.id, EQ) - if name_contains: - condition &= Condition("name", name_contains, EQ) & Condition( - "pattern", True, EQ - ) - response = self.controller.annotation_classes.list(condition) - if response.errors: - raise AppException(response.errors) - return response.data + condition = ( + Condition("name", name_contains, EQ) & Condition("pattern", True, EQ) + if name_contains + else None + ) + return [ + annotation_class.dict() + for annotation_class in project.list_annotation_classes(condition) + ] def set_project_status(self, project: NotEmptyStr, status: PROJECT_STATUS): """Set project status @@ -1353,8 +1356,7 @@ def upload_video_to_project( def create_annotation_class( self, - # todo project: Union[Project, NotEmptyStr], - project: Union[NotEmptyStr], + project: Union[NotEmptyStr, dict, Project], name: NotEmptyStr, color: NotEmptyStr, attribute_groups: Optional[List[AttributeGroupSchema]] = None, @@ -1447,38 +1449,41 @@ def create_annotation_class( ) """ - # todo if isinstance(project, Project): - # if isinstance(project, Project): - # project = project.dict() - _class_type = ClassTypeEnum(class_type) - # try: - # annotation_class = AnnotationClass( - # { - # "name": name, - # "color": color, - # "type": ClassTypeEnum(class_type), - # "attribute_groups": attribute_groups, - # } - # ) - # except ValidationError as e: - # raise AppException(wrap_error(e)) - project = self.controller.get_project(project) + if isinstance(project, Project): + project = project.dict() + + project_name, _ = extract_project_folder(project) + project = self.controller.get_project(project_name) + + if project.type == ProjectType.Pixel.value and any( + "default_value" in attr_group.keys() for attr_group in attribute_groups + ): + raise AppException( + 'The "default_value" key is not supported for project type Pixel.' + ) + + _class_type = ClassTypeEnum.get_value(class_type) if ( project.type != ProjectType.Document - and _class_type == ClassTypeEnum.RELATIONSHIP + and _class_type == ClassTypeEnum.relationship ): raise AppException( f"{class_type} class type is not supported in {project.type.name} project." ) - project.create_annotation_class( + annotation_class = project.create_annotation_class( name=name, color=color, class_type=_class_type, attribute_groups=attribute_groups, ) + if annotation_class: + return annotation_class.dict() + raise AppException("Failed to create annotation class") def delete_annotation_class( - self, project: NotEmptyStr, annotation_class: Union[dict, NotEmptyStr] + self, + project: Union[NotEmptyStr, dict, Project], + annotation_class: Union[dict, NotEmptyStr], ): """Deletes annotation class from project @@ -1489,24 +1494,32 @@ def delete_annotation_class( :type annotation_class: str or dict """ - if isinstance(annotation_class, str): - try: - annotation_class = AnnotationClassEntity( - name=annotation_class, - color="#ffffff", # noqa Random, just need to serialize - ) - except ValidationError as e: - raise AppException(wrap_error(e)) + if isinstance(annotation_class, dict) and "name" in annotation_class.keys(): + class_name = annotation_class["name"] + elif isinstance(annotation_class, str): + class_name = annotation_class else: - annotation_class = AnnotationClassEntity(**annotation_class) - project = self.controller.projects.get_by_name(project).data + raise AppException("Invalid value provided for annotation_class.") - self.controller.annotation_classes.delete( - project=project, annotation_class=annotation_class - ) + if isinstance(project, Project): + project = project.dict() + + project_name, _ = extract_project_folder(project) + project = self.controller.get_project(project_name) + + condition = Condition("name", class_name, EQ) & Condition("pattern", True, EQ) + annotation_classes = project.list_annotation_classes(condition=condition) + if annotation_classes: + class_to_delete = annotation_classes[0] + logger.info( + "Deleting annotation class from project %s with name %s", + project.name, + class_to_delete.name, + ) + project.delete_annotation_class(class_id=class_to_delete.id) def download_annotation_classes_json( - self, project: NotEmptyStr, folder: Union[str, Path] + self, project: Union[NotEmptyStr, dict, Project], folder: Union[str, Path] ): """Downloads project classes.json to folder @@ -1519,20 +1532,28 @@ def download_annotation_classes_json( :return: path of the download file :rtype: str """ + if isinstance(project, Project): + project = project.dict() - project = self.controller.projects.get_by_name(project).data - response = self.controller.annotation_classes.download( - project=project, download_path=folder + project_name, _ = extract_project_folder(project) + project = self.controller.get_project(project_name) + logger.info( + f"Downloading classes.json from project {project.name} to folder {str(folder)}." ) - if response.errors: - raise AppException(response.errors) - return response.data + annotation_classes: List[dict] = [ + annotation_class.dict() + for annotation_class in project.list_annotation_classes() + ] + json_path = f"{folder}/classes.json" + with open(json_path, "w") as f: + json.dump(annotation_classes, f, indent=4) + return json_path def create_annotation_classes_from_classes_json( self, - project: Union[NotEmptyStr, dict], + project: Union[NotEmptyStr, dict, Project], classes_json: Union[List[AnnotationClassEntity], str, Path], - from_s3_bucket=False, + from_s3_bucket: str = None, ): """Creates annotation classes in project from a SuperAnnotate format annotation classes.json. @@ -1549,7 +1570,7 @@ def create_annotation_classes_from_classes_json( :return: list of created annotation class metadatas :rtype: list of dicts """ - if isinstance(classes_json, str) or isinstance(classes_json, Path): + if isinstance(classes_json, (str, Path)): if from_s3_bucket: from_session = boto3.Session() from_s3 = from_session.resource("s3") @@ -1561,12 +1582,12 @@ def create_annotation_classes_from_classes_json( else: data = open(classes_json, encoding="utf-8") classes_json = json.load(data) - # try: - # # annotation_classes = parse_obj_as(List[AnnotationClassEntity], classes_json) - # annotation_classes = [AnnotationClassEntity.from_json(i) for i in classes_json] - # except ValidationError as _: - # raise AppException("Couldn't validate annotation classes.") - project = self.controller.get_project(project) + + if isinstance(project, Project): + project = project.dict() + + project_name, _ = extract_project_folder(project) + project = self.controller.get_project(project_name) annotation_classes = project.create_annotation_classes(classes_json) return [i.dict() for i in annotation_classes] @@ -1684,7 +1705,10 @@ def download_image( return response.data def upload_annotations( - self, project: NotEmptyStr, annotations: List[dict], keep_status: bool = False + self, + project: Union[NotEmptyStr, dict], + annotations: List[dict], + keep_status: bool = False, ): """Uploads a list of annotation dicts as annotations to the SuperAnnotate directory. diff --git a/src/superannotate/lib/core/usecases/__init__.py b/src/superannotate/lib/core/usecases/__init__.py index 7141f56b0..ff977e8fc 100644 --- a/src/superannotate/lib/core/usecases/__init__.py +++ b/src/superannotate/lib/core/usecases/__init__.py @@ -1,5 +1,4 @@ from lib.core.usecases.annotations import * # noqa: F403 F401 -from lib.core.usecases.classes import * # noqa: F403 F401 from lib.core.usecases.images import * # noqa: F403 F401 from lib.core.usecases.integrations import * # noqa: F403 F401 from lib.core.usecases.items import * # noqa: F403 F401 diff --git a/src/superannotate/lib/core/usecases/classes.py b/src/superannotate/lib/core/usecases/classes.py deleted file mode 100644 index e0642d3ab..000000000 --- a/src/superannotate/lib/core/usecases/classes.py +++ /dev/null @@ -1,221 +0,0 @@ -import json -import logging -from typing import List - -from lib.core.conditions import Condition -from lib.core.conditions import CONDITION_EQ as EQ -from lib.core.entities import AnnotationClassEntity -from lib.core.entities import ProjectEntity -from lib.core.entities.classes import GroupTypeEnum -from lib.core.enums import ProjectType -from lib.core.exceptions import AppException -from lib.core.serviceproviders import BaseServiceProvider -from lib.core.usecases.base import BaseUseCase - -logger = logging.getLogger("sa") - - -class GetAnnotationClassesUseCase(BaseUseCase): - def __init__( - self, - service_provider: BaseServiceProvider, - condition: Condition = None, - ): - super().__init__() - self._service_provider = service_provider - self._condition = condition - - def execute(self): - response = self._service_provider.annotation_classes.list(self._condition) - if response.ok: - classes = [ - entity.dict(by_alias=True, exclude_unset=True) - for entity in response.data - ] - self._response.data = classes - else: - self._response.errors = response.error - return self._response - - -class CreateAnnotationClassUseCase(BaseUseCase): - def __init__( - self, - service_provider: BaseServiceProvider, - annotation_class: AnnotationClassEntity, - project: ProjectEntity, - ): - super().__init__() - self._service_provider = service_provider - self._annotation_class = annotation_class - self._project = project - - def _is_unique(self): - annotation_classes = self._service_provider.annotation_classes.list( - Condition("project_id", self._project.id, EQ) - ).data - return not any( - [ - True - for annotation_class in annotation_classes - if annotation_class.name == self._annotation_class.name - ] - ) - - def validate_project_type(self): - if ( - self._project.type == ProjectType.PIXEL - and self._annotation_class.type == "tag" - ): - raise AppException( - "Predefined tagging functionality is not supported for projects" - f" of type {ProjectType.get_name(self._project.type)}." - ) - if self._project.type != ProjectType.VECTOR: - for g in self._annotation_class.attribute_groups: - if g.group_type == GroupTypeEnum.OCR: - raise AppException( - f"OCR attribute group is not supported for project type " - f"{ProjectType.get_name(self._project.type)}." - ) - - def validate_default_value(self): - if self._project.type == ProjectType.PIXEL.value and any( - getattr(attr_group, "default_value", None) - for attr_group in getattr(self._annotation_class, "attribute_groups", []) - ): - raise AppException( - 'The "default_value" key is not supported for project type Pixel.' - ) - - def execute(self): - if self.is_valid(): - if self._is_unique(): - response = self._service_provider.annotation_classes.create_multiple( - project=self._project, - classes=[self._annotation_class], - ) - if response.ok: - self._response.data = response.data[0] - else: - self._response.errors = AppException( - response.error.replace(". ", ".\n") - ) - else: - logger.error("This class name already exists. Skipping.") - return self._response - - -class CreateAnnotationClassesUseCase(BaseUseCase): - CHUNK_SIZE = 500 - - def __init__( - self, - service_provider: BaseServiceProvider, - annotation_classes: List[AnnotationClassEntity], - project: ProjectEntity, - ): - super().__init__() - self._project = project - self._service_provider = service_provider - self._annotation_classes = annotation_classes - - def validate_project_type(self): - if self._project.type != ProjectType.VECTOR: - for c in self._annotation_classes: - if self._project.type == ProjectType.PIXEL and c.type == "tag": - raise AppException( - f"Predefined tagging functionality is not supported" - f" for projects of type {ProjectType.get_name(self._project.type)}." - ) - for g in c.attribute_groups: - if g.group_type == GroupTypeEnum.OCR: - raise AppException( - f"OCR attribute group is not supported for project type " - f"{ProjectType.get_name(self._project.type)}." - ) - - def validate_default_value(self): - if self._project.type == ProjectType.PIXEL.value: - for annotation_class in self._annotation_classes: - if any( - getattr(attr_group, "default_value", None) - for attr_group in getattr(annotation_class, "attribute_groups", []) - ): - raise AppException( - 'The "default_value" key is not supported for project type Pixel.' - ) - - def execute(self): - if self.is_valid(): - existing_annotation_classes = ( - self._service_provider.annotation_classes.list( - Condition("project_id", self._project.id, EQ) - ).data - ) - existing_classes_name = [i.name for i in existing_annotation_classes] - unique_annotation_classes = [] - for annotation_class in self._annotation_classes: - if annotation_class.name not in existing_classes_name: - unique_annotation_classes.append(annotation_class) - not_unique_classes_count = len(self._annotation_classes) - len( - unique_annotation_classes - ) - if not_unique_classes_count: - logger.warning( - f"{not_unique_classes_count} annotation classes already exist.Skipping." - ) - created = [] - chunk_failed = False - # this is in reverse order because of the front-end - for i in range(len(unique_annotation_classes), 0, -self.CHUNK_SIZE): - response = self._service_provider.annotation_classes.create_multiple( - project=self._project, - classes=unique_annotation_classes[i - self.CHUNK_SIZE : i], # noqa - ) - if response.ok: - created.extend(response.data) - else: - logger.debug(response.error) - chunk_failed = True - if created: - logger.info( - f"{len(created)} annotation classes were successfully created in {self._project.name}." - ) - if chunk_failed: - self._response.errors = AppException( - "The classes couldn't be validated." - ) - self._response.data = created - return self._response - - -# TODO delete -class DownloadAnnotationClassesUseCase(BaseUseCase): - def __init__( - self, - download_path: str, - project: ProjectEntity, - service_provider: BaseServiceProvider, - ): - super().__init__() - self._download_path = download_path - self._project = project - self._service_provider = service_provider - - def execute(self): - logger.info( - f"Downloading classes.json from project {self._project.name} to folder {str(self._download_path)}." - ) - response = self._service_provider.annotation_classes.list( - Condition("project_id", self._project.id, EQ) - ) - if response.ok: - classes = [ - entity.dict(by_alias=True, exclude_unset=True) - for entity in response.data - ] - json_path = f"{self._download_path}/classes.json" - json.dump(classes, open(json_path, "w"), indent=4) - self._response.data = json_path - return self._response diff --git a/src/superannotate/lib/core/usecases/models.py b/src/superannotate/lib/core/usecases/models.py index 4f64535c8..75c3062d9 100644 --- a/src/superannotate/lib/core/usecases/models.py +++ b/src/superannotate/lib/core/usecases/models.py @@ -30,7 +30,6 @@ from lib.core.usecases.annotations import DownloadAnnotations from lib.core.usecases.base import BaseReportableUseCase from lib.core.usecases.base import BaseUseCase -from lib.core.usecases.classes import DownloadAnnotationClassesUseCase logger = logging.getLogger("sa") @@ -321,6 +320,7 @@ def execute(self): return self._response +# TODO fix class ConsensusUseCase(BaseUseCase): def __init__( self, diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index 891b44c6c..1929aaed8 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -24,12 +24,10 @@ from lib.core.entities import SettingEntity from lib.core.entities import TeamEntity from lib.core.entities import UserEntity -from lib.core.entities.classes import AnnotationClassEntity from lib.core.entities.integrations import IntegrationEntity from lib.core.exceptions import AppException from lib.core.reporter import Reporter from lib.core.response import Response -from lib.infrastructure.helpers import timed_lru_cache from lib.infrastructure.repositories import S3Repository from lib.infrastructure.serviceprovider import ServiceProvider from lib.infrastructure.services.http_client import HttpClient @@ -205,91 +203,6 @@ def upload_priority_scores( return use_case.execute() -# TODO delete -class AnnotationClassManager(BaseManager): - @timed_lru_cache(seconds=3600) - def __get_auth_data(self, project: ProjectEntity, folder: FolderEntity): - response = self.service_provider.get_s3_upload_auth_token(project, folder) - if not response.ok: - raise AppException(response.error) - return response.data - - def _get_s3_repository(self, project: ProjectEntity, folder: FolderEntity): - auth_data = self.__get_auth_data(project, folder) - return S3Repository( - auth_data["accessKeyId"], - auth_data["secretAccessKey"], - auth_data["sessionToken"], - auth_data["bucket"], - auth_data["region"], - ) - - def create(self, project: ProjectEntity, annotation_class: AnnotationClassEntity): - use_case = usecases.CreateAnnotationClassUseCase( - annotation_class=annotation_class, - project=project, - service_provider=self.service_provider, - ) - return use_case.execute() - - def create_multiple( - self, project: ProjectEntity, annotation_classes: List[AnnotationClassEntity] - ): - use_case = usecases.CreateAnnotationClassesUseCase( - service_provider=self.service_provider, - annotation_classes=annotation_classes, - project=project, - ) - return use_case.execute() - - def list(self, condition: Condition): - use_case = usecases.GetAnnotationClassesUseCase( - service_provider=self.service_provider, - condition=condition, - ) - return use_case.execute() - - def delete(self, project: ProjectEntity, annotation_class: AnnotationClassEntity): - use_case = usecases.DeleteAnnotationClassUseCase( - annotation_class=annotation_class, - project=project, - service_provider=self.service_provider, - ) - return use_case.execute() - - def copy_multiple( - self, - source_project: ProjectEntity, - source_folder: FolderEntity, - source_item: BaseItemEntity, - destination_project: ProjectEntity, - destination_folder: FolderEntity, - destination_item: BaseItemEntity, - ): - use_case = usecases.CopyImageAnnotationClasses( - from_project=source_project, - from_folder=source_folder, - from_image=source_item, - to_project=destination_project, - to_folder=destination_folder, - to_image=destination_item, - service_provider=self.service_provider, - from_project_s3_repo=self._get_s3_repository(source_project, source_folder), - to_project_s3_repo=self._get_s3_repository( - destination_project, destination_folder - ), - ) - return use_case.execute() - - def download(self, project: ProjectEntity, download_path: str): - use_case = usecases.DownloadAnnotationClassesUseCase( - project=project, - download_path=download_path, - service_provider=self.service_provider, - ) - return use_case.execute() - - class ItemManager(BaseManager): def get_by_name( self, @@ -676,9 +589,6 @@ def __init__(self, config: ConfigEntity): ) self._user = self.get_current_user() self._team = self.get_team().data - self.annotation_classes = AnnotationClassManager( - self.service_provider, self._session - ) self.projects = ProjectManager(self.service_provider, self._session) self.items = ItemManager(self.service_provider, self._session) self.annotations = AnnotationManager( diff --git a/tests/integration/annotations/test_uopload_annotations_without_classes.py b/tests/integration/annotations/test_uopload_annotations_without_classes.py index 4baa72b0d..55d0efc9b 100644 --- a/tests/integration/annotations/test_uopload_annotations_without_classes.py +++ b/tests/integration/annotations/test_uopload_annotations_without_classes.py @@ -51,5 +51,5 @@ def test_annotation_upload(self): classes_path = sa.download_annotation_classes_json( self.PROJECT_NAME, classes_dir ) - classes_json = json.load(open(classes_path)) - self.assertEqual(classes_json[0]["type"], "tag") + classes = json.load(open(classes_path)) + self.assertEqual(classes[0]["type"], "tag") diff --git a/tests/integration/classes/test_create_annotation_class.py b/tests/integration/classes/test_create_annotation_class.py index ac0fa5148..bf7923414 100644 --- a/tests/integration/classes/test_create_annotation_class.py +++ b/tests/integration/classes/test_create_annotation_class.py @@ -4,7 +4,7 @@ import pytest from src.superannotate import AppException from src.superannotate import SAClient -from src.superannotate.lib.core.entities.classes import AnnotationClassEntity +from superannotate_core.core.exceptions import SAException from tests import DATA_SET_PATH from tests.integration.base import BaseTestCase @@ -105,10 +105,14 @@ def test_create_radio_annotation_class_attr_required(self): self.assertEqual(msg, '"classes[0].attribute_groups[0].attributes" is required') def test_create_annotation_class_backend_errors(self): - - response = sa.controller.annotation_classes.create( - sa.controller.projects.get_by_name(self.PROJECT_NAME).data, - AnnotationClassEntity( + validation_errors = [ + """classes[0].attribute_groups[0].attributes" is required.""", + """classes[0].attribute_groups[1].attributes" is required.""", + """classes[0].attribute_groups[2].default_value" must be a string""", + ] + try: + sa.create_annotation_class( + project=self.PROJECT_NAME, name="t", color="blue", attribute_groups=[ @@ -121,41 +125,10 @@ def test_create_annotation_class_backend_errors(self): "attributes": [], }, ], - ), - ) - - assert ( - response.errors - == '"classes[0].attribute_groups[0].attributes" is required.\n' - '"classes[0].attribute_groups[1].attributes" is required.\n' - '"classes[0].attribute_groups[2].default_value" must be a string' - ) - - def test_create_annotation_classes_with_empty_default_attribute(self): - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, - classes_json=[ - { - "name": "Personal vehicle", - "color": "#ecb65f", - "createdAt": "2020-10-12T11:35:20.000Z", - "updatedAt": "2020-10-12T11:48:19.000Z", - "attribute_groups": [ - { - "name": "test", - "group_type": "radio", - "attributes": [ - {"name": "Car"}, - {"name": "Track"}, - {"name": "Bus"}, - ], - } - ], - } - ], - ) - classes = sa.search_annotation_classes(self.PROJECT_NAME) - assert classes[0]["attribute_groups"][0]["default_value"] is None + ) + except SAException as e: + for error in validation_errors: + self.assertIn(error, str(e)) def test_class_creation_type(self): with tempfile.TemporaryDirectory() as tmpdir_name: diff --git a/tests/integration/classes/test_create_annotation_classes_from_classes_json.py b/tests/integration/classes/test_create_annotation_classes_from_classes_json.py index f80705183..33dc58005 100644 --- a/tests/integration/classes/test_create_annotation_classes_from_classes_json.py +++ b/tests/integration/classes/test_create_annotation_classes_from_classes_json.py @@ -48,6 +48,7 @@ def test_create_annotation_class_from_json(self): ) self.assertEqual(len(sa.search_annotation_classes(self.PROJECT_NAME)), 4) + # TODO failed after SDK_core integration (check validation in future) def test_invalid_json(self): try: sa.create_annotation_classes_from_classes_json( @@ -119,6 +120,7 @@ def test_create_annotation_class(self): "Predefined tagging functionality is not supported for projects of type Video.", ) + # TODO failed after SDK_core integration (check validation in future) def test_create_annotation_class_via_json_and_ocr_group_type(self): with tempfile.TemporaryDirectory() as tmpdir_name: temp_path = f"{tmpdir_name}/new_classes.json" @@ -158,6 +160,44 @@ def test_create_annotation_class_via_json_and_ocr_group_type(self): self.PROJECT_NAME, temp_path ) + def test_create_annotation_classes_with_empty_default_attribute(self): + with tempfile.TemporaryDirectory() as tmpdir_name: + temp_path = f"{tmpdir_name}/new_classes.json" + with open(temp_path, "w") as new_classes: + new_classes.write( + """ + [ + { + "id":56820, + "project_id":7617, + "name":"Personal vehicle", + "color":"#547497", + "count":18, + "type": "tag", + "attribute_groups":[ + { + "id":21448, + "class_id":56820, + "name":"Large", + "group_type": "radio", + "attributes":[ + {"name": "Car"}, + {"name": "Track"}, + {"name": "Bus"} + ] + } + ] + } + ] + """ + ) + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, + classes_json=temp_path, + ) + classes = sa.search_annotation_classes(self.PROJECT_NAME) + assert classes[0]["attribute_groups"][0]["default_value"] is None + class TestPixelCreateAnnotationClass(BaseTestCase): PROJECT_NAME = "TestCreateAnnotationClassPixel" @@ -169,30 +209,44 @@ class TestPixelCreateAnnotationClass(BaseTestCase): def large_json_path(self): return os.path.join(DATA_SET_PATH, self.TEST_LARGE_CLASSES_JSON) + # TODO failed after SDK_core integration (check validation in future) def test_create_annotation_classes_with_default_attribute(self): with self.assertRaisesRegexp( AppException, 'The "default_value" key is not supported for project type Pixel.', ): - sa.create_annotation_classes_from_classes_json( - self.PROJECT_NAME, - classes_json=[ - { - "name": "Personal vehicle", - "color": "#ecb65f", - "createdAt": "2020-10-12T11:35:20.000Z", - "updatedAt": "2020-10-12T11:48:19.000Z", - "attribute_groups": [ - { - "name": "test", - "attributes": [ - {"name": "Car"}, - {"name": "Track"}, - {"name": "Bus"}, - ], - "default_value": "Bus", - } - ], - } - ], - ) + with tempfile.TemporaryDirectory() as tmpdir_name: + temp_path = f"{tmpdir_name}/new_classes.json" + with open(temp_path, "w") as new_classes: + new_classes.write( + """ + [ + { + "id":56820, + "project_id":7617, + "name":"Personal vehicle", + "color":"#547497", + "count":18, + "type": "tag", + "attribute_groups":[ + { + "id":21448, + "class_id":56820, + "name":"Large", + "group_type": "radio", + "attributes":[ + {"name": "Car"}, + {"name": "Track"}, + {"name": "Bus"} + ], + "default_value": "Bus" + } + ] + } + ] + """ + ) + sa.create_annotation_classes_from_classes_json( + self.PROJECT_NAME, + classes_json=temp_path, + ) diff --git a/tests/integration/classes/test_delete_annotation_class.py b/tests/integration/classes/test_delete_annotation_class.py new file mode 100644 index 000000000..e4b7957cc --- /dev/null +++ b/tests/integration/classes/test_delete_annotation_class.py @@ -0,0 +1,24 @@ +from src.superannotate import SAClient +from tests.integration.base import BaseTestCase + + +sa = SAClient() + + +class TestVectorAnnotationClassesDelete(BaseTestCase): + PROJECT_NAME = "TestVectorAnnotationClassesDelete" + PROJECT_DESCRIPTION = "test description" + PROJECT_TYPE = "Vector" + + def setUp(self, *args, **kwargs): + super().setUp() + sa.create_annotation_class( + self.PROJECT_NAME, "test_annotation_class", "#FFFFFF" + ) + classes = sa.search_annotation_classes(self.PROJECT_NAME) + self.assertEqual(len(classes), 1) + + def test_delete_annotation_class(self): + sa.delete_annotation_class(self.PROJECT_NAME, "test_annotation_class") + classes = sa.search_annotation_classes(self.PROJECT_NAME) + self.assertEqual(len(classes), 0) diff --git a/tests/integration/custom_fields/test_custom_schema.py b/tests/integration/custom_fields/test_custom_schema.py index 1c621405f..eaa3b0e78 100644 --- a/tests/integration/custom_fields/test_custom_schema.py +++ b/tests/integration/custom_fields/test_custom_schema.py @@ -1,6 +1,5 @@ import copy -from src.superannotate import AppException from src.superannotate import SAClient from superannotate_core.core.exceptions import SAException from tests.integration.base import BaseTestCase