Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ minversion = 3.7
log_cli=true
python_files = test_*.py
;pytest_plugins = ['pytest_profiling']
;addopts = -n 6 --dist loadscope
addopts = -n 6 --dist loadscope
12 changes: 10 additions & 2 deletions src/superannotate/lib/core/serviceproviders.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,13 +851,17 @@ def invite_contributors(

@abstractmethod
def list_custom_field_names(
self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: dict,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> List[str]:
raise NotImplementedError

@abstractmethod
def get_custom_field_id(
self,
context: dict,
field_name: str,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
Expand All @@ -867,6 +871,7 @@ def get_custom_field_id(
@abstractmethod
def get_custom_field_name(
self,
context: dict,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
Expand All @@ -884,6 +889,9 @@ def get_custom_field_component_id(

@abstractmethod
def get_custom_fields_templates(
self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: dict,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
):
raise NotImplementedError
4 changes: 3 additions & 1 deletion src/superannotate/lib/core/usecases/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ def execute(self):
else:
project.users = []
if self._include_custom_fields:
context = {"team_id": self._project.team_id}
custom_fields_names = self._service_provider.list_custom_field_names(
self._project.team_id,
context,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
)
Expand All @@ -173,6 +174,7 @@ def execute(self):
custom_fields_name_value_map = {}
for name in custom_fields_names:
field_id = self._service_provider.get_custom_field_id(
context,
name,
entity=CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
Expand Down
24 changes: 12 additions & 12 deletions src/superannotate/lib/infrastructure/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,10 @@ def serialize_custom_fields(
entity: CustomFieldEntityEnum,
parent_entity: CustomFieldEntityEnum,
) -> List[dict]:
pk = (
project_id
if entity == CustomFieldEntityEnum.PROJECT
else (team_id if parent_entity == CustomFieldEntityEnum.TEAM else project_id)
)
context = {"team_id": team_id, "project_id": project_id}

existing_custom_fields = service_provider.list_custom_field_names(
pk, entity, parent=parent_entity
context, entity, parent=parent_entity
)
for i in range(len(data)):
if not data[i]:
Expand All @@ -111,7 +107,7 @@ def serialize_custom_fields(
field_value /= 1000 # Convert timestamp

new_field_name = service_provider.get_custom_field_name(
field_id, entity=entity, parent=parent_entity
context, field_id, entity=entity, parent=parent_entity
)
updated_fields[new_field_name] = field_value

Expand Down Expand Up @@ -151,11 +147,12 @@ def set_custom_field_value(
field_name: str,
value: Any,
):
_context = {}
_context = {"team_id": self.service_provider.client.team_id}
if entity == CustomFieldEntityEnum.PROJECT:
_context["project_id"] = entity_id

template_id = self.service_provider.get_custom_field_id(
field_name, entity=entity, parent=parent_entity
_context, field_name, entity=entity, parent=parent_entity
)
component_id = self.service_provider.get_custom_field_component_id(
template_id, entity=entity, parent=parent_entity
Expand All @@ -178,16 +175,17 @@ def set_custom_field_value(
def list_users(
self, include: List[Literal["custom_fields"]] = None, project=None, **filters
):
context = {"team_id": self.service_provider.client.team_id}
if project:
parent_entity = CustomFieldEntityEnum.PROJECT
project_id = project.id
project_id = context["project_id"] = project.id
else:
parent_entity = CustomFieldEntityEnum.TEAM
project_id = None
valid_fields = generate_schema(
UserFilters.__annotations__,
self.service_provider.get_custom_fields_templates(
CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
context, CustomFieldEntityEnum.CONTRIBUTOR, parent=parent_entity
),
)
chain = QueryBuilderChain(
Expand Down Expand Up @@ -574,7 +572,9 @@ def list_projects(
valid_fields = generate_schema(
ProjectFilters.__annotations__,
self.service_provider.get_custom_fields_templates(
CustomFieldEntityEnum.PROJECT, parent=CustomFieldEntityEnum.TEAM
{"team_id": self.service_provider.client.team_id},
CustomFieldEntityEnum.PROJECT,
parent=CustomFieldEntityEnum.TEAM,
),
)
chain = QueryBuilderChain(
Expand Down
13 changes: 3 additions & 10 deletions src/superannotate/lib/infrastructure/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,25 +175,18 @@ def __init__(
self._team_id = team_id
self._project_id = project_id

@property
def pk(self):
if self._entity == CustomFieldEntityEnum.PROJECT:
return self._project_id
if self._parent == CustomFieldEntityEnum.TEAM:
return self._team_id
return self._project_id

def _handle_custom_field_key(self, key) -> Tuple[str, str, Optional[str]]:
context = {"team_id": self._team_id, "project_id": self._project_id}
for custom_field in sorted(
self._service_provider.list_custom_field_names(
self.pk, self._entity, parent=self._parent
context, self._entity, parent=self._parent
),
key=len,
reverse=True,
):
if custom_field in key:
custom_field_id = self._service_provider.get_custom_field_id(
custom_field, entity=self._entity, parent=self._parent
context, custom_field, entity=self._entity, parent=self._parent
)
component_id = self._service_provider.get_custom_field_component_id(
custom_field_id, entity=self._entity, parent=self._parent
Expand Down
23 changes: 16 additions & 7 deletions src/superannotate/lib/infrastructure/serviceprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from lib.infrastructure.services.telemetry_scoring import TelemetryScoringService
from lib.infrastructure.services.work_management import WorkManagementService
from lib.infrastructure.utils import CachedWorkManagementRepository
from lib.infrastructure.utils import EntityContext


class ServiceProvider(BaseServiceProvider):
Expand Down Expand Up @@ -72,17 +73,23 @@ def __init__(self, client: HttpClient):
)

def get_custom_fields_templates(
self, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: EntityContext,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
):
return self._cached_work_management_repository.list_templates(
self.client.team_id, entity=entity, parent=parent
context, entity=entity, parent=parent
)

def list_custom_field_names(
self, pk, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: EntityContext,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> List[str]:
return self._cached_work_management_repository.list_custom_field_names(
pk,
context,
entity=entity,
parent=parent,
)
Expand All @@ -96,22 +103,24 @@ def get_category_id(

def get_custom_field_id(
self,
context: EntityContext,
field_name: str,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> int:
return self._cached_work_management_repository.get_custom_field_id(
self.client.team_id, field_name, entity=entity, parent=parent
context, field_name, entity=entity, parent=parent
)

def get_custom_field_name(
self,
context: EntityContext,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> str:
return self._cached_work_management_repository.get_custom_field_name(
self.client.team_id, field_id, entity=entity, parent=parent
context, field_id, entity=entity, parent=parent
)

def get_custom_field_component_id(
Expand All @@ -121,7 +130,7 @@ def get_custom_field_component_id(
parent: CustomFieldEntityEnum,
) -> str:
return self._cached_work_management_repository.get_custom_field_component_id(
self.client.team_id, field_id, entity=entity, parent=parent
{"team_id": self.client.team_id}, field_id, entity=entity, parent=parent
)

def get_role_id(self, project: entities.ProjectEntity, role_name: str) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,7 @@ def set_custom_field_value(
url=self.URL_SET_CUSTOM_ENTITIES.format(pk=entity_id),
method="patch",
headers={
"x-sa-entity-context": self._generate_context(
team_id=self.client.team_id, **context
),
"x-sa-entity-context": self._generate_context(**context),
},
data={"customField": {"custom_field_values": {template_id: data}}},
params={
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ def _store_annotation(path, annotation: dict, callback: Callable = None):
def _process_data(self, data):
if data and self._map_function:
return self._map_function(data)
return data
return data
72 changes: 52 additions & 20 deletions src/superannotate/lib/infrastructure/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import time
import typing
from abc import ABC
from abc import abstractmethod
from functools import wraps
Expand All @@ -24,6 +25,11 @@
logger = logging.getLogger("sa")


class EntityContext(typing.TypedDict, total=False):
team_id: int
project_id: Optional[int]


def divide_to_chunks(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
Expand Down Expand Up @@ -324,78 +330,104 @@ def get_annotation_status_name(self, project, status_value: int) -> str:

def get_custom_field_id(
self,
team_id: int,
context: EntityContext,
field_name: str,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> int:
if entity == CustomFieldEntityEnum.PROJECT:
custom_field_data = self._project_custom_field_cache.get(team_id)
custom_field_data = self._project_custom_field_cache.get(context["team_id"])
else:
if parent == CustomFieldEntityEnum.TEAM:
custom_field_data = self._team_user_custom_field_cache.get(team_id)
custom_field_data = self._team_user_custom_field_cache.get(
context["team_id"]
)
else:
custom_field_data = self._project_user_custom_field_cache.get(team_id)
custom_field_data = self._project_user_custom_field_cache.get(
context["project_id"]
)
if field_name in custom_field_data["custom_fields_name_id_map"]:
return custom_field_data["custom_fields_name_id_map"][field_name]
raise AppException("Invalid custom field name provided.")

def get_custom_field_name(
self,
team_id: int,
context: EntityContext,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> str:
if entity == CustomFieldEntityEnum.PROJECT:
custom_field_data = self._project_custom_field_cache.get(team_id)
custom_field_data = self._project_custom_field_cache.get(context["team_id"])
else:
if parent == CustomFieldEntityEnum.TEAM:
custom_field_data = self._team_user_custom_field_cache.get(team_id)
custom_field_data = self._team_user_custom_field_cache.get(
context["team_id"]
)
else:
custom_field_data = self._project_user_custom_field_cache.get(team_id)
custom_field_data = self._project_user_custom_field_cache.get(
context["project_id"]
)
if field_id in custom_field_data["custom_fields_id_name_map"]:
return custom_field_data["custom_fields_id_name_map"][field_id]
raise AppException("Invalid custom field ID provided.")

def get_custom_field_component_id(
self,
team_id: int,
context: EntityContext,
field_id: int,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> str:
if entity == CustomFieldEntityEnum.PROJECT:
custom_field_data = self._project_custom_field_cache.get(team_id)
custom_field_data = self._project_custom_field_cache.get(context["team_id"])
else:
if parent == CustomFieldEntityEnum.TEAM:
custom_field_data = self._team_user_custom_field_cache.get(team_id)
custom_field_data = self._team_user_custom_field_cache.get(
context["team_id"]
)
else:
custom_field_data = self._project_user_custom_field_cache.get(team_id)
custom_field_data = self._project_user_custom_field_cache.get(
context["project_id"]
)
if field_id in custom_field_data["custom_fields_id_component_id_map"]:
return custom_field_data["custom_fields_id_component_id_map"][field_id]
raise AppException("Invalid custom field ID provided.")

def list_custom_field_names(
self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: EntityContext,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
) -> list:
if entity == CustomFieldEntityEnum.PROJECT:
custom_field_data = self._project_custom_field_cache.get(pk)
custom_field_data = self._project_custom_field_cache.get(context["team_id"])
else:
if parent == CustomFieldEntityEnum.TEAM:
custom_field_data = self._team_user_custom_field_cache.get(pk)
custom_field_data = self._team_user_custom_field_cache.get(
context["team_id"]
)
else:
custom_field_data = self._project_user_custom_field_cache.get(pk)
custom_field_data = self._project_user_custom_field_cache.get(
context["project_id"]
)
return list(custom_field_data["custom_fields_name_id_map"].keys())

def list_templates(
self, pk: int, entity: CustomFieldEntityEnum, parent: CustomFieldEntityEnum
self,
context: EntityContext,
entity: CustomFieldEntityEnum,
parent: CustomFieldEntityEnum,
):
if entity == CustomFieldEntityEnum.PROJECT:
return self._project_custom_field_cache.get(pk)["templates"]
return self._project_custom_field_cache.get(context["team_id"])["templates"]
elif entity == CustomFieldEntityEnum.CONTRIBUTOR:
if parent == CustomFieldEntityEnum.TEAM:
return self._team_user_custom_field_cache.get(pk)["templates"]
return self._team_user_custom_field_cache.get(context["team_id"])[
"templates"
]
else:
return self._project_user_custom_field_cache.get(pk)["templates"]
return self._project_user_custom_field_cache.get(context["project_id"])[
"templates"
]
raise AppException("Invalid entity provided.")
Loading