From e23b92584985eea3e035abc5e592ef01c29d11a0 Mon Sep 17 00:00:00 2001 From: Vaghinak Basentsyan Date: Tue, 4 Mar 2025 17:51:09 +0400 Subject: [PATCH] Added ability to get project users tod --- src/superannotate/__init__.py | 2 +- .../lib/app/interface/sdk_interface.py | 28 +++++- .../lib/core/entities/work_managament.py | 30 ++++++ .../lib/core/serviceproviders.py | 32 +++++- .../lib/core/usecases/projects.py | 12 ++- .../lib/infrastructure/annotation_adapter.py | 4 +- .../lib/infrastructure/controller.py | 68 ++++++++++--- .../lib/infrastructure/query_builder.py | 18 +++- .../lib/infrastructure/serviceprovider.py | 35 +++++-- .../services/work_management.py | 31 ++++-- src/superannotate/lib/infrastructure/utils.py | 97 ++++++++++++++++--- .../test_user_custom_fields.py | 18 ++++ 12 files changed, 311 insertions(+), 64 deletions(-) diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index e6e8b7684..cec08ea20 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -3,7 +3,7 @@ import sys -__version__ = "4.4.32" +__version__ = "4.4.33dev1" os.environ.update({"sa_version": __version__}) sys.path.append(os.path.split(os.path.realpath(__file__))[0]) diff --git a/src/superannotate/lib/app/interface/sdk_interface.py b/src/superannotate/lib/app/interface/sdk_interface.py index 2dfe694b6..c6e091650 100644 --- a/src/superannotate/lib/app/interface/sdk_interface.py +++ b/src/superannotate/lib/app/interface/sdk_interface.py @@ -372,9 +372,18 @@ def set_user_custom_field( parent_entity=CustomFieldEntityEnum.TEAM, ) - def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filters): + def list_users( + self, + *, + project: Union[int, str] = None, + include: List[Literal["custom_fields"]] = None, + **filters, + ): """ Search users by filtering criteria + :param project: Project name or ID, if provided, results will be for project-level, + otherwise results will be for team level. + :type project: str or int :param include: Specifies additional fields to be included in the response. @@ -454,9 +463,22 @@ def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filter } ] """ - return BaseSerializer.serialize_iterable( - self.controller.work_management.list_users(include=include, **filters) + if project is not None: + if isinstance(project, int): + project = self.controller.get_project_by_id(project) + else: + project = self.controller.get_project(project) + response = BaseSerializer.serialize_iterable( + self.controller.work_management.list_users( + project=project, include=include, **filters + ) ) + if project: + for user in response: + user["role"] = self.controller.service_provider.get_role_name( + project, user["role"] + ) + return response def pause_user_activity( self, pk: Union[int, str], projects: Union[List[int], List[str], Literal["*"]] diff --git a/src/superannotate/lib/core/entities/work_managament.py b/src/superannotate/lib/core/entities/work_managament.py index 1539236bf..b78fab9de 100644 --- a/src/superannotate/lib/core/entities/work_managament.py +++ b/src/superannotate/lib/core/entities/work_managament.py @@ -119,3 +119,33 @@ def json(self, **kwargs): if "exclude" not in kwargs: kwargs["exclude"] = {"custom_fields"} return super().json(**kwargs) + + +class WMProjectUserEntity(TimedBaseModel): + id: Optional[int] + team_id: Optional[int] + role: int + email: Optional[str] + state: Optional[WMUserStateEnum] + custom_fields: Optional[dict] = Field(dict(), alias="customField") + + class Config: + extra = Extra.ignore + use_enum_names = True + + json_encoders = { + Enum: lambda v: v.value, + datetime.date: lambda v: v.isoformat(), + datetime.datetime: lambda v: v.isoformat(), + } + + @validator("custom_fields") + def custom_fields_transformer(cls, v): + if v and "custom_field_values" in v: + return v.get("custom_field_values", {}) + return {} + + def json(self, **kwargs): + if "exclude" not in kwargs: + kwargs["exclude"] = {"custom_fields"} + return super().json(**kwargs) diff --git a/src/superannotate/lib/core/serviceproviders.py b/src/superannotate/lib/core/serviceproviders.py index d8fbab28e..cf9d08441 100644 --- a/src/superannotate/lib/core/serviceproviders.py +++ b/src/superannotate/lib/core/serviceproviders.py @@ -147,7 +147,12 @@ def create_project_categories( @abstractmethod def list_users( - self, body_query: Query, chunk_size=100, include_custom_fields=False + self, + body_query: Query, + parent_entity: str = "Team", + chunk_size=100, + project_id: int = None, + include_custom_fields=False, ) -> WMUserListResponse: raise NotImplementedError @@ -804,23 +809,40 @@ def invite_contributors( raise NotImplementedError @abstractmethod - def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]: + def list_custom_field_names( + self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + ) -> List[str]: raise NotImplementedError @abstractmethod def get_custom_field_id( - self, field_name: str, entity: CustomFieldEntityEnum + self, + field_name: str, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> int: raise NotImplementedError @abstractmethod def get_custom_field_name( - self, field_id: int, entity: CustomFieldEntityEnum + self, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: raise NotImplementedError @abstractmethod def get_custom_field_component_id( - self, field_id: int, entity: CustomFieldEntityEnum + self, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: raise NotImplementedError + + @abstractmethod + def get_custom_fields_templates( + self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + ): + raise NotImplementedError diff --git a/src/superannotate/lib/core/usecases/projects.py b/src/superannotate/lib/core/usecases/projects.py index 9c9bf56a4..7e5de2148 100644 --- a/src/superannotate/lib/core/usecases/projects.py +++ b/src/superannotate/lib/core/usecases/projects.py @@ -155,7 +155,9 @@ def execute(self): project.users = [] if self._include_custom_fields: custom_fields_names = self._service_provider.list_custom_field_names( - entity=CustomFieldEntityEnum.PROJECT + self._project.team_id, + entity=CustomFieldEntityEnum.PROJECT, + parent=CustomFieldEntityEnum.TEAM, ) if custom_fields_names: project_custom_fields = ( @@ -171,7 +173,9 @@ def execute(self): custom_fields_name_value_map = {} for name in custom_fields_names: field_id = self._service_provider.get_custom_field_id( - name, entity=CustomFieldEntityEnum.PROJECT + name, + entity=CustomFieldEntityEnum.PROJECT, + parent=CustomFieldEntityEnum.TEAM, ) field_value = ( custom_fields_id_value_map[str(field_id)] @@ -180,7 +184,9 @@ def execute(self): ) # timestamp: convert milliseconds to seconds component_id = self._service_provider.get_custom_field_component_id( - field_id, entity=CustomFieldEntityEnum.PROJECT + field_id, + entity=CustomFieldEntityEnum.PROJECT, + parent=CustomFieldEntityEnum.TEAM, ) if ( field_value diff --git a/src/superannotate/lib/infrastructure/annotation_adapter.py b/src/superannotate/lib/infrastructure/annotation_adapter.py index c333231cf..13436510a 100644 --- a/src/superannotate/lib/infrastructure/annotation_adapter.py +++ b/src/superannotate/lib/infrastructure/annotation_adapter.py @@ -44,7 +44,9 @@ def get_component_value(self, component_id: str): return None def set_component_value(self, component_id: str, value: Any): - self.annotation.setdefault("data", {}).setdefault(component_id, {})["value"] = value + self.annotation.setdefault("data", {}).setdefault(component_id, {})[ + "value" + ] = value return self diff --git a/src/superannotate/lib/infrastructure/controller.py b/src/superannotate/lib/infrastructure/controller.py index 0e9f703b1..5c0179132 100644 --- a/src/superannotate/lib/infrastructure/controller.py +++ b/src/superannotate/lib/infrastructure/controller.py @@ -73,9 +73,22 @@ def build_condition(**kwargs) -> Condition: def serialize_custom_fields( - service_provider: ServiceProvider, data: List[dict], entity: CustomFieldEntityEnum + team_id: int, + project_id: int, + service_provider: ServiceProvider, + data: List[dict], + entity: CustomFieldEntityEnum, + parent_entity: CustomFieldEntityEnum, ) -> List[dict]: - existing_custom_fields = service_provider.list_custom_field_names(entity) + pk = ( + project_id + if entity == CustomFieldEntityEnum.PROJECT + else (team_id if parent_entity == CustomFieldEntityEnum.TEAM else project_id) + ) + + existing_custom_fields = service_provider.list_custom_field_names( + pk, entity, parent=parent_entity + ) for i in range(len(data)): if not data[i]: data[i] = {} @@ -85,7 +98,7 @@ def serialize_custom_fields( field_id = int(custom_field_name) try: component_id = service_provider.get_custom_field_component_id( - field_id, entity=entity + field_id, entity=entity, parent=parent_entity ) except AppException: # The component template can be deleted, but not from the entity, so it will be skipped. @@ -95,7 +108,7 @@ def serialize_custom_fields( field_value /= 1000 # Convert timestamp new_field_name = service_provider.get_custom_field_name( - field_id, entity=entity + field_id, entity=entity, parent=parent_entity ) updated_fields[new_field_name] = field_value @@ -139,10 +152,10 @@ def set_custom_field_value( if entity == CustomFieldEntityEnum.PROJECT: _context["project_id"] = entity_id template_id = self.service_provider.get_custom_field_id( - field_name, entity=entity + field_name, entity=entity, parent=parent_entity ) component_id = self.service_provider.get_custom_field_component_id( - template_id, entity=entity + template_id, entity=entity, parent=parent_entity ) # timestamp: convert seconds to milliseconds if component_id == CustomFieldType.DATE_PICKER.value and value is not None: @@ -159,40 +172,59 @@ def set_custom_field_value( context=_context, ) - def list_users(self, include: List[Literal["custom_fields"]] = None, **filters): + def list_users( + self, include: List[Literal["custom_fields"]] = None, project=None, **filters + ): + if project: + parent_entity = CustomFieldEntityEnum.PROJECT + project_id = project.id + else: + parent_entity = CustomFieldEntityEnum.TEAM + project_id = None valid_fields = generate_schema( UserFilters.__annotations__, self.service_provider.get_custom_fields_templates( - CustomFieldEntityEnum.CONTRIBUTOR + CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity ), ) chain = QueryBuilderChain( [ FieldValidationHandler(valid_fields.keys()), UserFilterHandler( + team_id=self.service_provider.client.team_id, + project_id=project_id, service_provider=self.service_provider, entity=CustomFieldEntityEnum.CONTRIBUTOR, + parent=parent_entity, ), ] ) query = chain.handle(filters, EmptyQuery()) if include and "custom_fields" in include: response = self.service_provider.work_management.list_users( - query, include_custom_fields=True + query, + include_custom_fields=True, + parent_entity=parent_entity, + project_id=project_id, ) if not response.ok: raise AppException(response.error) users = response.data custom_fields_list = [user.custom_fields for user in users] serialized_fields = serialize_custom_fields( + self.service_provider.client.team_id, + project_id, self.service_provider, custom_fields_list, - CustomFieldEntityEnum.CONTRIBUTOR, + entity=CustomFieldEntityEnum.CONTRIBUTOR, + parent_entity=parent_entity, ) for users, serialized_custom_fields in zip(users, serialized_fields): users.custom_fields = serialized_custom_fields return response.data - return self.service_provider.work_management.list_users(query).data + return self.service_provider.work_management.list_users( + query, parent_entity=parent_entity, project_id=project_id + ).data def update_user_activity( self, @@ -406,14 +438,18 @@ def list_projects( valid_fields = generate_schema( ProjectFilters.__annotations__, self.service_provider.get_custom_fields_templates( - CustomFieldEntityEnum.PROJECT + CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM ), ) chain = QueryBuilderChain( [ FieldValidationHandler(valid_fields.keys()), ProjectFilterHandler( - self.service_provider, entity=CustomFieldEntityEnum.PROJECT + team_id=self.service_provider.client.team_id, + project_id=None, + service_provider=self.service_provider, + entity=CustomFieldEntityEnum.PROJECT, + parent=CustomFieldEntityEnum.TEAM, ), ] ) @@ -435,7 +471,11 @@ def list_projects( if include_custom_fields: custom_fields_list = [project.custom_fields for project in projects] serialized_fields = serialize_custom_fields( - self.service_provider, custom_fields_list, CustomFieldEntityEnum.PROJECT + self.service_provider.client.team_id, + None, + self.service_provider, + custom_fields_list, + CustomFieldEntityEnum.PROJECT, ) for project, serialized_custom_fields in zip(projects, serialized_fields): project.custom_fields = serialized_custom_fields diff --git a/src/superannotate/lib/infrastructure/query_builder.py b/src/superannotate/lib/infrastructure/query_builder.py index 4473e502b..edb24c5b1 100644 --- a/src/superannotate/lib/infrastructure/query_builder.py +++ b/src/superannotate/lib/infrastructure/query_builder.py @@ -152,23 +152,31 @@ def _handle_special_fields(self, keys: List[str], val): class BaseCustomFieldHandler(AbstractQueryHandler): def __init__( - self, service_provider: BaseServiceProvider, entity: CustomFieldEntityEnum + self, + team_id: int, + project_id: Optional[int], + service_provider: BaseServiceProvider, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ): self._service_provider = service_provider self._entity = entity + self._parent = parent def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]: for custom_field in sorted( - self._service_provider.list_custom_field_names(entity=self._entity), + self._service_provider.list_custom_field_names( + entity=self._entity, parent=self._parent + ), key=len, reverse=True, ): if custom_field in key: custom_field_id = self._service_provider.get_custom_field_id( - custom_field, entity=self._entity + custom_field, entity=self._entity, parent=self._parent ) component_id = self._service_provider.get_custom_field_component_id( - custom_field_id, entity=self._entity + custom_field_id, entity=self._entity, parent=self._parent ) key = key.replace( custom_field, @@ -209,7 +217,7 @@ def _determine_condition_and_key(keys: List[str]) -> Tuple[OperatorEnum, str]: def _handle_special_fields(self, keys: List[str], val): if keys[0] == "custom_field": component_id = self._service_provider.get_custom_field_component_id( - field_id=int(keys[1]), entity=self._entity + field_id=int(keys[1]), entity=self._entity, parent=self._parent ) if component_id == CustomFieldType.DATE_PICKER.value and val is not None: try: diff --git a/src/superannotate/lib/infrastructure/serviceprovider.py b/src/superannotate/lib/infrastructure/serviceprovider.py index 641111a49..e05be536e 100644 --- a/src/superannotate/lib/infrastructure/serviceprovider.py +++ b/src/superannotate/lib/infrastructure/serviceprovider.py @@ -69,35 +69,50 @@ def __init__(self, client: HttpClient): 5, self.work_management ) - def get_custom_fields_templates(self, entity: CustomFieldEntityEnum): + def get_custom_fields_templates( + self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + ): return self._cached_work_management_repository.list_templates( - self.client.team_id, entity=entity + self.client.team_id, entity=entity, parent=parent ) - def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]: + def list_custom_field_names( + self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + ) -> List[str]: return self._cached_work_management_repository.list_custom_field_names( - self.client.team_id, entity=entity + pk, + entity=entity, + parent=parent, ) def get_custom_field_id( - self, field_name: str, entity: CustomFieldEntityEnum + self, + field_name: str, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> int: return self._cached_work_management_repository.get_custom_field_id( - self.client.team_id, field_name, entity=entity + self.client.team_id, field_name, entity=entity, parent=parent ) def get_custom_field_name( - self, field_id: int, entity: CustomFieldEntityEnum + self, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: return self._cached_work_management_repository.get_custom_field_name( - self.client.team_id, field_id, entity=entity + self.client.team_id, field_id, entity=entity, parent=parent ) def get_custom_field_component_id( - self, field_id: int, entity: CustomFieldEntityEnum + self, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: return self._cached_work_management_repository.get_custom_field_component_id( - self.client.team_id, field_id, entity=entity + self.client.team_id, field_id, entity=entity, parent=parent ) def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int: diff --git a/src/superannotate/lib/infrastructure/services/work_management.py b/src/superannotate/lib/infrastructure/services/work_management.py index df1f1e39f..937de899b 100644 --- a/src/superannotate/lib/infrastructure/services/work_management.py +++ b/src/superannotate/lib/infrastructure/services/work_management.py @@ -6,6 +6,7 @@ from lib.core.entities import CategoryEntity from lib.core.entities import WorkflowEntity from lib.core.entities.work_managament import WMProjectEntity +from lib.core.entities.work_managament import WMProjectUserEntity from lib.core.entities.work_managament import WMUserEntity from lib.core.enums import CustomFieldEntityEnum from lib.core.exceptions import AppException @@ -60,6 +61,7 @@ class WorkManagementService(BaseWorkManagementService): URL_SET_CUSTOM_ENTITIES = "customentities/{pk}" URL_SEARCH_CUSTOM_ENTITIES = "customentities/search" URL_SEARCH_TEAM_USERS = "teamusers/search" + URL_SEARCH_PROJECT_USERS = "projectusers/search" URL_SEARCH_PROJECTS = "projects/search" URL_RESUME_PAUSE_USER = "teams/editprojectsusers" @@ -259,27 +261,42 @@ def search_projects( ) def list_users( - self, body_query: Query, chunk_size=100, include_custom_fields=False + self, + body_query: Query, + chunk_size=100, + parent_entity: str = "Team", + project_id: int = None, + include_custom_fields=False, ) -> WMUserListResponse: if include_custom_fields: url = self.URL_SEARCH_CUSTOM_ENTITIES else: - url = self.URL_SEARCH_TEAM_USERS + if parent_entity == "Team": + url = self.URL_SEARCH_TEAM_USERS + else: + url = self.URL_SEARCH_PROJECT_USERS + if project_id is None: + user_entity = WMUserEntity + entity_context = self._generate_context(team_id=self.client.team_id) + else: + user_entity = WMProjectUserEntity + entity_context = self._generate_context( + team_id=self.client.team_id, + project_id=project_id, + ) return self.client.jsx_paginate( url=url, method="post", body_query=body_query, query_params={ "entity": "Contributor", - "parentEntity": "Team", + "parentEntity": parent_entity, }, headers={ - "x-sa-entity-context": self._generate_context( - team_id=self.client.team_id - ), + "x-sa-entity-context": entity_context, }, chunk_size=chunk_size, - item_type=WMUserEntity, + item_type=user_entity, ) def create_custom_field_template( diff --git a/src/superannotate/lib/infrastructure/utils.py b/src/superannotate/lib/infrastructure/utils.py index eaa75e80d..25069d747 100644 --- a/src/superannotate/lib/infrastructure/utils.py +++ b/src/superannotate/lib/infrastructure/utils.py @@ -219,6 +219,38 @@ def get(self, key, **kwargs): return self._K_V_map[key] +class ProjectUserCustomFieldCache(CustomFieldCache): + def sync(self, project_id): + response = self.work_management.list_custom_field_templates( + entity=self._entity, + parent_entity=self._parent_entity, + context={"project_id": project_id}, + ) + if not response.ok: + raise AppException(response.error) + custom_fields_name_id_map = { + field["name"]: field["id"] for field in response.data["data"] + } + custom_fields_id_name_map = { + field["id"]: field["name"] for field in response.data["data"] + } + custom_fields_id_component_id_map = { + field["id"]: field["component_id"] for field in response.data["data"] + } + self._K_V_map[project_id] = { + "custom_fields_name_id_map": custom_fields_name_id_map, + "custom_fields_id_name_map": custom_fields_id_name_map, + "custom_fields_id_component_id_map": custom_fields_id_component_id_map, + "templates": response.data["data"], + } + self._update_cache_timestamp(project_id) + + def get(self, key, **kwargs): + if not self._is_cache_valid(key): + self.sync(project_id=key) + return self._K_V_map[key] + + class CachedWorkManagementRepository: def __init__(self, ttl_seconds: int, work_management): self._role_cache = RoleCache(ttl_seconds, work_management) @@ -229,12 +261,18 @@ def __init__(self, ttl_seconds: int, work_management): CustomFieldEntityEnum.PROJECT, CustomFieldEntityEnum.TEAM, ) - self._user_custom_field_cache = CustomFieldCache( + self._team_user_custom_field_cache = CustomFieldCache( ttl_seconds, work_management, CustomFieldEntityEnum.CONTRIBUTOR, CustomFieldEntityEnum.TEAM, ) + self._project_user_custom_field_cache = ProjectUserCustomFieldCache( + ttl_seconds, + work_management, + CustomFieldEntityEnum.CONTRIBUTOR, + CustomFieldEntityEnum.PROJECT, + ) def get_role_id(self, project, role_name: str) -> int: role_data = self._role_cache.get(project.id, project=project) @@ -261,50 +299,79 @@ def get_annotation_status_name(self, project, status_value: int) -> str: raise AppException("Invalid status value provided.") def get_custom_field_id( - self, team_id: int, field_name: str, entity: CustomFieldEntityEnum + self, + team_id: int, + field_name: str, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> int: if entity == CustomFieldEntityEnum.PROJECT: custom_field_data = self._project_custom_field_cache.get(team_id) else: - custom_field_data = self._user_custom_field_cache.get(team_id) + if parent == CustomFieldEntityEnum.TEAM: + custom_field_data = self._team_user_custom_field_cache.get(team_id) + else: + custom_field_data = self._project_user_custom_field_cache.get(team_id) if field_name in custom_field_data["custom_fields_name_id_map"]: return custom_field_data["custom_fields_name_id_map"][field_name] raise AppException("Invalid custom field name provided.") def get_custom_field_name( - self, team_id: int, field_id: int, entity: CustomFieldEntityEnum + self, + team_id: int, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: if entity == CustomFieldEntityEnum.PROJECT: custom_field_data = self._project_custom_field_cache.get(team_id) else: - custom_field_data = self._user_custom_field_cache.get(team_id) + if parent == CustomFieldEntityEnum.TEAM: + custom_field_data = self._team_user_custom_field_cache.get(team_id) + else: + custom_field_data = self._project_user_custom_field_cache.get(team_id) if field_id in custom_field_data["custom_fields_id_name_map"]: return custom_field_data["custom_fields_id_name_map"][field_id] raise AppException("Invalid custom field ID provided.") def get_custom_field_component_id( - self, team_id: int, field_id: int, entity: CustomFieldEntityEnum + self, + team_id: int, + field_id: int, + entity: CustomFieldEntityEnum, + parent: CustomFieldEntityEnum, ) -> str: if entity == CustomFieldEntityEnum.PROJECT: custom_field_data = self._project_custom_field_cache.get(team_id) else: - custom_field_data = self._user_custom_field_cache.get(team_id) + if parent == CustomFieldEntityEnum.TEAM: + custom_field_data = self._team_user_custom_field_cache.get(team_id) + else: + custom_field_data = self._project_user_custom_field_cache.get(team_id) if field_id in custom_field_data["custom_fields_id_component_id_map"]: return custom_field_data["custom_fields_id_component_id_map"][field_id] raise AppException("Invalid custom field ID provided.") def list_custom_field_names( - self, team_id: int, entity: CustomFieldEntityEnum + self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum ) -> list: if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(team_id) + custom_field_data = self._project_custom_field_cache.get(pk) else: - custom_field_data = self._user_custom_field_cache.get(team_id) + if parent == CustomFieldEntityEnum.TEAM: + custom_field_data = self._team_user_custom_field_cache.get(pk) + else: + custom_field_data = self._project_user_custom_field_cache.get(pk) return list(custom_field_data["custom_fields_name_id_map"].keys()) - def list_templates(self, team_id: int, entity: CustomFieldEntityEnum): + def list_templates( + self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum + ): if entity == CustomFieldEntityEnum.PROJECT: - custom_field_data = self._project_custom_field_cache.get(team_id) - else: - custom_field_data = self._user_custom_field_cache.get(team_id) - return custom_field_data["templates"] + return self._project_custom_field_cache.get(pk)["templates"] + elif entity == CustomFieldEntityEnum.CONTRIBUTOR: + if parent == CustomFieldEntityEnum.TEAM: + return self._team_user_custom_field_cache.get(pk)["templates"] + else: + return self._project_user_custom_field_cache.get(pk)["templates"] + raise AppException("Invalid entity provided.") diff --git a/tests/integration/work_management/test_user_custom_fields.py b/tests/integration/work_management/test_user_custom_fields.py index f1eea65bf..4955433ae 100644 --- a/tests/integration/work_management/test_user_custom_fields.py +++ b/tests/integration/work_management/test_user_custom_fields.py @@ -5,8 +5,10 @@ from lib.core.exceptions import AppException from src.superannotate import SAClient from src.superannotate.lib.core.enums import CustomFieldEntityEnum +from tests.integration.base import BaseTestCase from tests.integration.work_management.data_set import CUSTOM_FIELD_PAYLOADS + sa = SAClient() @@ -231,3 +233,19 @@ def test_set_user_custom_field_validation(self): error_template_select.format(type="str", options="option1, option2"), ): sa.set_user_custom_field(scapegoat["email"], "SDK_test_single_select", 123) + + +class TestUserProjectCustomFields(BaseTestCase): + PROJECT_NAME = "TestUserProjectCustomFields" + PROJECT_TYPE = "Multimodal" + PROJECT_DESCRIPTION = "Multimodal" + + def test_project_custom_fields(self): + scapegoat = sa.list_users(role="contributor")[0] + sa.add_contributors_to_project( + self.PROJECT_NAME, [scapegoat["email"]], role="QA" + ) + users = sa.list_users(project=self.PROJECT_NAME) + assert users[0]["role"] == "QA" + users = sa.list_users(project=self.PROJECT_NAME, include=["custom_fields"]) + assert users[0]["role"] == "QA"