diff --git a/pytest.ini b/pytest.ini index d9f7f6cc3..c0f66b58e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,4 +3,4 @@ minversion = 3.7 log_cli=true python_files = test_*.py ;pytest_plugins = ['pytest_profiling'] -;addopts = -n 6 --dist loadscope +addopts = -n 6 --dist loadscope diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index 697408a13..207f17775 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -851,13 +851,17 @@ def invite_contributors( @abstractmethod def list_custom_field_names( - self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: dict, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> List[str]: raise NotImplementedError @abstractmethod def get_custom_field_id( self, + context: dict, field_name: str, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, @@ -867,6 +871,7 @@ def get_custom_field_id( @abstractmethod def get_custom_field_name( self, + context: dict, field_id: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, @@ -884,6 +889,9 @@ def get_custom_field_component_id( @abstractmethod def get_custom_fields_templates( - self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: dict, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ): raise NotImplementedError diff --git a/src/superannotate/lib/core/usecases/projects.py b/src/superannotate/lib/core/usecases/projects.py index 7e5de2148..26f35d938 100644 --- a/src/superannotate/lib/core/usecases/projects.py +++ b/src/superannotate/lib/core/usecases/projects.py @@ -154,8 +154,9 @@ def execute(self): else: project.users = [] if self._include_custom_fields: + context = {"team_id": self._project.team_id} custom_fields_names = self._service_provider.list_custom_field_names( - self._project.team_id, + context, entity=CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM, ) @@ -173,6 +174,7 @@ def execute(self): custom_fields_name_value_map = {} for name in custom_fields_names: field_id = self._service_provider.get_custom_field_id( + context, name, entity=CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM, diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index 619a58d7d..7e7919f91 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -83,14 +83,10 @@ def serialize_custom_fields( entity: CustomFieldEntityEnum, parent_entity: CustomFieldEntityEnum, ) -> List[dict]: - pk = ( - project_id - if entity == CustomFieldEntityEnum.PROJECT - else (team_id if parent_entity == CustomFieldEntityEnum.TEAM else project_id) - ) + context = {"team_id": team_id, "project_id": project_id} existing_custom_fields = service_provider.list_custom_field_names( - pk, entity, parent=parent_entity + context, entity, parent=parent_entity ) for i in range(len(data)): if not data[i]: @@ -111,7 +107,7 @@ def serialize_custom_fields( field_value /= 1000 # Convert timestamp new_field_name = service_provider.get_custom_field_name( - field_id, entity=entity, parent=parent_entity + context, field_id, entity=entity, parent=parent_entity ) updated_fields[new_field_name] = field_value @@ -151,11 +147,12 @@ def set_custom_field_value( field_name: str, value: Any, ): - _context = {} + _context = {"team_id": self.service_provider.client.team_id} if entity == CustomFieldEntityEnum.PROJECT: _context["project_id"] = entity_id + template_id = self.service_provider.get_custom_field_id( - field_name, entity=entity, parent=parent_entity + _context, field_name, entity=entity, parent=parent_entity ) component_id = self.service_provider.get_custom_field_component_id( template_id, entity=entity, parent=parent_entity @@ -178,16 +175,17 @@ def set_custom_field_value( def list_users( self, include: List[Literal["custom_fields"]] = None, project=None, **filters ): + context = {"team_id": self.service_provider.client.team_id} if project: parent_entity = CustomFieldEntityEnum.PROJECT - project_id = project.id + project_id = context["project_id"] = project.id else: parent_entity = CustomFieldEntityEnum.TEAM project_id = None valid_fields = generate_schema( UserFilters.__annotations__, self.service_provider.get_custom_fields_templates( - CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity + context, CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity ), ) chain = QueryBuilderChain( @@ -574,7 +572,9 @@ def list_projects( valid_fields = generate_schema( ProjectFilters.__annotations__, self.service_provider.get_custom_fields_templates( - CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM + {"team_id": self.service_provider.client.team_id}, + CustomFieldEntityEnum.PROJECT, + parent=CustomFieldEntityEnum.TEAM, ), ) chain = QueryBuilderChain( diff --git a/src/superannotate/lib/infrastructure/query_builder.py b/src/superannotate/lib/infrastructure/query_builder.py index 60f07781a..4cc69046e 100644 --- a/src/superannotate/lib/infrastructure/query_builder.py +++ b/src/superannotate/lib/infrastructure/query_builder.py @@ -175,25 +175,18 @@ def __init__( self._team_id = team_id self._project_id = project_id - @property - def pk(self): - if self._entity == CustomFieldEntityEnum.PROJECT: - return self._project_id - if self._parent == CustomFieldEntityEnum.TEAM: - return self._team_id - return self._project_id - def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]: + context = {"team_id": self._team_id, "project_id": self._project_id} for custom_field in sorted( self._service_provider.list_custom_field_names( - self.pk, self._entity, parent=self._parent + context, self._entity, parent=self._parent ), key=len, reverse=True, ): if custom_field in key: custom_field_id = self._service_provider.get_custom_field_id( - custom_field, entity=self._entity, parent=self._parent + context, custom_field, entity=self._entity, parent=self._parent ) component_id = self._service_provider.get_custom_field_component_id( custom_field_id, entity=self._entity, parent=self._parent diff --git a/src/superannotate/lib/infrastructure/serviceprovider.py b/src/superannotate/lib/infrastructure/serviceprovider.py index 7364951fe..79c5a38fb 100644 --- a/src/superannotate/lib/infrastructure/serviceprovider.py +++ b/src/superannotate/lib/infrastructure/serviceprovider.py @@ -24,6 +24,7 @@ from lib.infrastructure.services.telemetry_scoring import TelemetryScoringService from lib.infrastructure.services.work_management import WorkManagementService from lib.infrastructure.utils import CachedWorkManagementRepository +from lib.infrastructure.utils import EntityContext class ServiceProvider(BaseServiceProvider): @@ -72,17 +73,23 @@ def __init__(self, client: HttpClient): ) def get_custom_fields_templates( - self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: EntityContext, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ): return self._cached_work_management_repository.list_templates( - self.client.team_id, entity=entity, parent=parent + context, entity=entity, parent=parent ) def list_custom_field_names( - self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: EntityContext, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> List[str]: return self._cached_work_management_repository.list_custom_field_names( - pk, + context, entity=entity, parent=parent, ) @@ -96,22 +103,24 @@ def get_category_id( def get_custom_field_id( self, + context: EntityContext, field_name: str, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, ) -> int: return self._cached_work_management_repository.get_custom_field_id( - self.client.team_id, field_name, entity=entity, parent=parent + context, field_name, entity=entity, parent=parent ) def get_custom_field_name( self, + context: EntityContext, field_id: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, ) -> str: return self._cached_work_management_repository.get_custom_field_name( - self.client.team_id, field_id, entity=entity, parent=parent + context, field_id, entity=entity, parent=parent ) def get_custom_field_component_id( @@ -121,7 +130,7 @@ def get_custom_field_component_id( parent: CustomFieldEntityEnum, ) -> str: return self._cached_work_management_repository.get_custom_field_component_id( - self.client.team_id, field_id, entity=entity, parent=parent + {"team_id": self.client.team_id}, field_id, entity=entity, parent=parent ) def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int: diff --git a/src/superannotate/lib/infrastructure/services/work_management.py b/src/superannotate/lib/infrastructure/services/work_management.py index 0ab43a5ef..375a6932b 100644 --- a/src/superannotate/lib/infrastructure/services/work_management.py +++ b/src/superannotate/lib/infrastructure/services/work_management.py @@ -375,9 +375,7 @@ def set_custom_field_value( url=self.URL_SET_CUSTOM_ENTITIES.format(pk=entity_id), method="patch", headers={ - "x-sa-entity-context": self._generate_context( - team_id=self.client.team_id, **context - ), + "x-sa-entity-context": self._generate_context(**context), }, data={"customField": {"custom_field_values": {template_id: data}}}, params={ diff --git a/src/superannotate/lib/infrastructure/stream_data_handler.py b/src/superannotate/lib/infrastructure/stream_data_handler.py index 32c6e59c4..f8a3759ee 100644 --- a/src/superannotate/lib/infrastructure/stream_data_handler.py +++ b/src/superannotate/lib/infrastructure/stream_data_handler.py @@ -174,4 +174,4 @@ def _store_annotation(path, annotation: dict, callback: Callable = None): def _process_data(self, data): if data and self._map_function: return self._map_function(data) - return data \ No newline at end of file + return data diff --git a/src/superannotate/lib/infrastructure/utils.py b/src/superannotate/lib/infrastructure/utils.py index ce80988ae..c3934d833 100644 --- a/src/superannotate/lib/infrastructure/utils.py +++ b/src/superannotate/lib/infrastructure/utils.py @@ -1,6 +1,7 @@ import asyncio import logging import time +import typing from abc import ABC from abc import abstractmethod from functools import wraps @@ -24,6 +25,11 @@ logger = logging.getLogger("sa") +class EntityContext(typing.TypedDict, total=False): + team_id: int + project_id: Optional[int] + + def divide_to_chunks(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -324,78 +330,104 @@ def get_annotation_status_name(self, project, status_value: int) -> str: def get_custom_field_id( self, - team_id: int, + context: EntityContext, field_name: str, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, ) -> int: if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(team_id) + custom_field_data = self._project_custom_field_cache.get(context["team_id"]) else: if parent == CustomFieldEntityEnum.TEAM: - custom_field_data = self._team_user_custom_field_cache.get(team_id) + custom_field_data = self._team_user_custom_field_cache.get( + context["team_id"] + ) else: - custom_field_data = self._project_user_custom_field_cache.get(team_id) + custom_field_data = self._project_user_custom_field_cache.get( + context["project_id"] + ) if field_name in custom_field_data["custom_fields_name_id_map"]: return custom_field_data["custom_fields_name_id_map"][field_name] raise AppException("Invalid custom field name provided.") def get_custom_field_name( self, - team_id: int, + context: EntityContext, field_id: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, ) -> str: if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(team_id) + custom_field_data = self._project_custom_field_cache.get(context["team_id"]) else: if parent == CustomFieldEntityEnum.TEAM: - custom_field_data = self._team_user_custom_field_cache.get(team_id) + custom_field_data = self._team_user_custom_field_cache.get( + context["team_id"] + ) else: - custom_field_data = self._project_user_custom_field_cache.get(team_id) + custom_field_data = self._project_user_custom_field_cache.get( + context["project_id"] + ) if field_id in custom_field_data["custom_fields_id_name_map"]: return custom_field_data["custom_fields_id_name_map"][field_id] raise AppException("Invalid custom field ID provided.") def get_custom_field_component_id( self, - team_id: int, + context: EntityContext, field_id: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum, ) -> str: if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(team_id) + custom_field_data = self._project_custom_field_cache.get(context["team_id"]) else: if parent == CustomFieldEntityEnum.TEAM: - custom_field_data = self._team_user_custom_field_cache.get(team_id) + custom_field_data = self._team_user_custom_field_cache.get( + context["team_id"] + ) else: - custom_field_data = self._project_user_custom_field_cache.get(team_id) + custom_field_data = self._project_user_custom_field_cache.get( + context["project_id"] + ) if field_id in custom_field_data["custom_fields_id_component_id_map"]: return custom_field_data["custom_fields_id_component_id_map"][field_id] raise AppException("Invalid custom field ID provided.") def list_custom_field_names( - self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: EntityContext, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> list: if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(pk) + custom_field_data = self._project_custom_field_cache.get(context["team_id"]) else: if parent == CustomFieldEntityEnum.TEAM: - custom_field_data = self._team_user_custom_field_cache.get(pk) + custom_field_data = self._team_user_custom_field_cache.get( + context["team_id"] + ) else: - custom_field_data = self._project_user_custom_field_cache.get(pk) + custom_field_data = self._project_user_custom_field_cache.get( + context["project_id"] + ) return list(custom_field_data["custom_fields_name_id_map"].keys()) def list_templates( - self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + self, + context: EntityContext, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ): if entity == CustomFieldEntityEnum.PROJECT: - return self._project_custom_field_cache.get(pk)["templates"] + return self._project_custom_field_cache.get(context["team_id"])["templates"] elif entity == CustomFieldEntityEnum.CONTRIBUTOR: if parent == CustomFieldEntityEnum.TEAM: - return self._team_user_custom_field_cache.get(pk)["templates"] + return self._team_user_custom_field_cache.get(context["team_id"])[ + "templates" + ] else: - return self._project_user_custom_field_cache.get(pk)["templates"] + return self._project_user_custom_field_cache.get(context["project_id"])[ + "templates" + ] raise AppException("Invalid entity provided.") diff --git a/tests/integration/work_management/test_project_custom_fields.py b/tests/integration/work_management/test_project_custom_fields.py index 88b3ad7f4..fea40a5b8 100644 --- a/tests/integration/work_management/test_project_custom_fields.py +++ b/tests/integration/work_management/test_project_custom_fields.py @@ -331,4 +331,5 @@ def test_list_projects_by_custom_invalid_field(self): # TODO BED issue (custom_field filter without join) def test_list_projects_by_custom_fields_without_join(self): self._set_custom_field_values() - assert sa.list_projects(custom_field__SDK_test_numeric=123) + with self.assertRaisesRegexp(AppException, "Internal server error"): + assert sa.list_projects(custom_field__SDK_test_numeric=123) diff --git a/tests/integration/work_management/test_user_scoring.py b/tests/integration/work_management/test_user_scoring.py index a4f8d6f9d..267d78f70 100644 --- a/tests/integration/work_management/test_user_scoring.py +++ b/tests/integration/work_management/test_user_scoring.py @@ -145,6 +145,7 @@ def test_set_get_scores(self): def test_list_users_with_scores(self): # list team users team_users = sa.list_users(include=["custom_fields"]) + print(team_users) for u in team_users: for s in SCORE_TEMPLATES: try: