Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,18 @@ def set_user_custom_field(
parent_entity=CustomFieldEntityEnum.TEAM,
)

def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filters):
def list_users(
self,
*,
project: Union[int, str] = None,
include: List[Literal["custom_fields"]] = None,
**filters,
):
"""
Search users by filtering criteria
:param project: Project name or ID, if provided, results will be for project-level,
otherwise results will be for team level.
:type project: str or int

:param include: Specifies additional fields to be included in the response.

Expand Down Expand Up @@ -454,9 +463,22 @@ def list_users(self, *, include: List[Literal["custom_fields"]] = None, **filter
}
]
"""
return BaseSerializer.serialize_iterable(
self.controller.work_management.list_users(include=include, **filters)
if project is not None:
if isinstance(project, int):
project = self.controller.get_project_by_id(project)
else:
project = self.controller.get_project(project)
response = BaseSerializer.serialize_iterable(
self.controller.work_management.list_users(
project=project, include=include, **filters
)
)
if project:
for user in response:
user["role"] = self.controller.service_provider.get_role_name(
project, user["role"]
)
return response

def pause_user_activity(
self, pk: Union[int, str], projects: Union[List[int], List[str], Literal["*"]]
Expand Down
30 changes: 30 additions & 0 deletions src/superannotate/lib/core/entities/work_managament.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,33 @@ def json(self, **kwargs):
if "exclude" not in kwargs:
kwargs["exclude"] = {"custom_fields"}
return super().json(**kwargs)


class WMProjectUserEntity(TimedBaseModel):
id: Optional[int]
team_id: Optional[int]
role: int
email: Optional[str]
state: Optional[WMUserStateEnum]
custom_fields: Optional[dict] = Field(dict(), alias="customField")

class Config:
extra = Extra.ignore
use_enum_names = True

json_encoders = {
Enum: lambda v: v.value,
datetime.date: lambda v: v.isoformat(),
datetime.datetime: lambda v: v.isoformat(),
}

@validator("custom_fields")
def custom_fields_transformer(cls, v):
if v and "custom_field_values" in v:
return v.get("custom_field_values", {})
return {}

def json(self, **kwargs):
if "exclude" not in kwargs:
kwargs["exclude"] = {"custom_fields"}
return super().json(**kwargs)
32 changes: 27 additions & 5 deletions src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,12 @@ def create_project_categories(

@abstractmethod
def list_users(
self, body_query: Query, chunk_size=100, include_custom_fields=False
self,
body_query: Query,
parent_entity: str = "Team",
chunk_size=100,
project_id: int = None,
include_custom_fields=False,
) -> WMUserListResponse:
raise NotImplementedError

Expand Down Expand Up @@ -812,23 +817,40 @@ def invite_contributors(
raise NotImplementedError

@abstractmethod
def list_custom_field_names(self, entity: CustomFieldEntityEnum) -> List[str]:
def list_custom_field_names(
self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
) -> List[str]:
raise NotImplementedError

@abstractmethod
def get_custom_field_id(
self, field_name: str, entity: CustomFieldEntityEnum
self,
field_name: str,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> int:
raise NotImplementedError

@abstractmethod
def get_custom_field_name(
self, field_id: int, entity: CustomFieldEntityEnum
self,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> str:
raise NotImplementedError

@abstractmethod
def get_custom_field_component_id(
self, field_id: int, entity: CustomFieldEntityEnum
self,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> str:
raise NotImplementedError

@abstractmethod
def get_custom_fields_templates(
self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
):
raise NotImplementedError
12 changes: 9 additions & 3 deletions src/superannotate/lib/core/usecases/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def execute(self):
project.users = []
if self._include_custom_fields:
custom_fields_names = self._service_provider.list_custom_field_names(
entity=CustomFieldEntityEnum.PROJECT
self._project.team_id,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
)
if custom_fields_names:
project_custom_fields = (
Expand All @@ -171,7 +173,9 @@ def execute(self):
custom_fields_name_value_map = {}
for name in custom_fields_names:
field_id = self._service_provider.get_custom_field_id(
name, entity=CustomFieldEntityEnum.PROJECT
name,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
)
field_value = (
custom_fields_id_value_map[str(field_id)]
Expand All @@ -180,7 +184,9 @@ def execute(self):
)
# timestamp: convert milliseconds to seconds
component_id = self._service_provider.get_custom_field_component_id(
field_id, entity=CustomFieldEntityEnum.PROJECT
field_id,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
)
if (
field_value
Expand Down
68 changes: 54 additions & 14 deletions src/superannotate/lib/infrastructure/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,22 @@ def build_condition(**kwargs) -> Condition:


def serialize_custom_fields(
service_provider: ServiceProvider, data: List[dict], entity: CustomFieldEntityEnum
team_id: int,
project_id: int,
service_provider: ServiceProvider,
data: List[dict],
entity: CustomFieldEntityEnum,
parent_entity: CustomFieldEntityEnum,
) -> List[dict]:
existing_custom_fields = service_provider.list_custom_field_names(entity)
pk = (
project_id
if entity == CustomFieldEntityEnum.PROJECT
else (team_id if parent_entity == CustomFieldEntityEnum.TEAM else project_id)
)

existing_custom_fields = service_provider.list_custom_field_names(
pk, entity, parent=parent_entity
)
for i in range(len(data)):
if not data[i]:
data[i] = {}
Expand All @@ -85,7 +98,7 @@ def serialize_custom_fields(
field_id = int(custom_field_name)
try:
component_id = service_provider.get_custom_field_component_id(
field_id, entity=entity
field_id, entity=entity, parent=parent_entity
)
except AppException:
# The component template can be deleted, but not from the entity, so it will be skipped.
Expand All @@ -95,7 +108,7 @@ def serialize_custom_fields(
field_value /= 1000 # Convert timestamp

new_field_name = service_provider.get_custom_field_name(
field_id, entity=entity
field_id, entity=entity, parent=parent_entity
)
updated_fields[new_field_name] = field_value

Expand Down Expand Up @@ -139,10 +152,10 @@ def set_custom_field_value(
if entity == CustomFieldEntityEnum.PROJECT:
_context["project_id"] = entity_id
template_id = self.service_provider.get_custom_field_id(
field_name, entity=entity
field_name, entity=entity, parent=parent_entity
)
component_id = self.service_provider.get_custom_field_component_id(
template_id, entity=entity
template_id, entity=entity, parent=parent_entity
)
# timestamp: convert seconds to milliseconds
if component_id == CustomFieldType.DATE_PICKER.value and value is not None:
Expand All @@ -159,40 +172,59 @@ def set_custom_field_value(
context=_context,
)

def list_users(self, include: List[Literal["custom_fields"]] = None, **filters):
def list_users(
self, include: List[Literal["custom_fields"]] = None, project=None, **filters
):
if project:
parent_entity = CustomFieldEntityEnum.PROJECT
project_id = project.id
else:
parent_entity = CustomFieldEntityEnum.TEAM
project_id = None
valid_fields = generate_schema(
UserFilters.__annotations__,
self.service_provider.get_custom_fields_templates(
CustomFieldEntityEnum.CONTRIBUTOR
CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
),
)
chain = QueryBuilderChain(
[
FieldValidationHandler(valid_fields.keys()),
UserFilterHandler(
team_id=self.service_provider.client.team_id,
project_id=project_id,
service_provider=self.service_provider,
entity=CustomFieldEntityEnum.CONTRIBUTOR,
parent=parent_entity,
),
]
)
query = chain.handle(filters, EmptyQuery())
if include and "custom_fields" in include:
response = self.service_provider.work_management.list_users(
query, include_custom_fields=True
query,
include_custom_fields=True,
parent_entity=parent_entity,
project_id=project_id,
)
if not response.ok:
raise AppException(response.error)
users = response.data
custom_fields_list = [user.custom_fields for user in users]
serialized_fields = serialize_custom_fields(
self.service_provider.client.team_id,
project_id,
self.service_provider,
custom_fields_list,
CustomFieldEntityEnum.CONTRIBUTOR,
entity=CustomFieldEntityEnum.CONTRIBUTOR,
parent_entity=parent_entity,
)
for users, serialized_custom_fields in zip(users, serialized_fields):
users.custom_fields = serialized_custom_fields
return response.data
return self.service_provider.work_management.list_users(query).data
return self.service_provider.work_management.list_users(
query, parent_entity=parent_entity, project_id=project_id
).data

def update_user_activity(
self,
Expand Down Expand Up @@ -406,14 +438,18 @@ def list_projects(
valid_fields = generate_schema(
ProjectFilters.__annotations__,
self.service_provider.get_custom_fields_templates(
CustomFieldEntityEnum.PROJECT
CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM
),
)
chain = QueryBuilderChain(
[
FieldValidationHandler(valid_fields.keys()),
ProjectFilterHandler(
self.service_provider, entity=CustomFieldEntityEnum.PROJECT
team_id=self.service_provider.client.team_id,
project_id=None,
service_provider=self.service_provider,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
),
]
)
Expand All @@ -435,7 +471,11 @@ def list_projects(
if include_custom_fields:
custom_fields_list = [project.custom_fields for project in projects]
serialized_fields = serialize_custom_fields(
self.service_provider, custom_fields_list, CustomFieldEntityEnum.PROJECT
self.service_provider.client.team_id,
None,
self.service_provider,
custom_fields_list,
CustomFieldEntityEnum.PROJECT,
)
for project, serialized_custom_fields in zip(projects, serialized_fields):
project.custom_fields = serialized_custom_fields
Expand Down
18 changes: 13 additions & 5 deletions src/superannotate/lib/infrastructure/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,23 +162,31 @@ def _handle_special_fields(self, keys: List[str], val):

class BaseCustomFieldHandler(AbstractQueryHandler):
def __init__(
self, service_provider: BaseServiceProvider, entity: CustomFieldEntityEnum
self,
team_id: int,
project_id: Optional[int],
service_provider: BaseServiceProvider,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
):
self._service_provider = service_provider
self._entity = entity
self._parent = parent

def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]:
for custom_field in sorted(
self._service_provider.list_custom_field_names(entity=self._entity),
self._service_provider.list_custom_field_names(
entity=self._entity, parent=self._parent
),
key=len,
reverse=True,
):
if custom_field in key:
custom_field_id = self._service_provider.get_custom_field_id(
custom_field, entity=self._entity
custom_field, entity=self._entity, parent=self._parent
)
component_id = self._service_provider.get_custom_field_component_id(
custom_field_id, entity=self._entity
custom_field_id, entity=self._entity, parent=self._parent
)
key = key.replace(
custom_field,
Expand Down Expand Up @@ -219,7 +227,7 @@ def _determine_condition_and_key(keys: List[str]) -> Tuple[OperatorEnum, str]:
def _handle_special_fields(self, keys: List[str], val):
if keys[0] == "custom_field":
component_id = self._service_provider.get_custom_field_component_id(
field_id=int(keys[1]), entity=self._entity
field_id=int(keys[1]), entity=self._entity, parent=self._parent
)
if component_id == CustomFieldType.DATE_PICKER.value and val is not None:
try:
Expand Down
Loading