Skip to content
This repository has been archived by the owner on Nov 14, 2023. It is now read-only.

Commit

Permalink
Fix: skeletons cannot be added to a task or project (cvat-ai#5813)
Browse files Browse the repository at this point in the history
  • Loading branch information
yasakova-anastasia authored and mikhail-treskin committed Jul 1, 2023
1 parent de880b5 commit 039bfe1
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428)
(<https://github.com/opencv/cvat/issues/4365>)
- Queries via the low-level API using the `multipart/form-data` Content-Type with string fields
(<https://github.com/opencv/cvat/pull/5479>)
- Skeletons cannot be added to a task or project (<https://github.com/opencv/cvat/pull/5813>)

### Security
- `Project.import_dataset` not waiting for completion correctly
Expand Down
31 changes: 16 additions & 15 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def validate(self, attrs):
def update_label(
cls,
validated_data: Dict[str, Any],
svg: str,
sublabels: Iterable[Dict[str, Any]],
*,
parent_instance: Union[models.Project, models.Task],
parent_label: Optional[models.Label] = None
Expand Down Expand Up @@ -299,6 +301,19 @@ def update_label(
raise exceptions.ValidationError(str(exc)) from exc
logger.info("New {} label was created".format(db_label.name))

cls.update_labels(sublabels, parent_instance=parent_instance, parent_label=db_label)

if db_label.type == str(models.LabelType.SKELETON):
for db_sublabel in list(db_label.sublabels.all()):
svg = svg.replace(
f'data-label-name="{db_sublabel.name}"',
f'data-label-id="{db_sublabel.id}"'
)
db_skeleton = models.Skeleton.objects.create(root=db_label, svg=svg)
logger.info(
f'label:update Skeleton id:{db_skeleton.id} for label_id:{db_label.id}'
)

if validated_data.get('deleted'):
assert validated_data['id'] # must be checked in the validate()
db_label.delete()
Expand Down Expand Up @@ -400,7 +415,7 @@ def update_labels(cls,
for label in labels:
sublabels = label.pop('sublabels', [])
svg = label.pop('svg', '')
db_label = cls.update_label(label,
db_label = cls.update_label(label, svg, sublabels,
parent_instance=parent_instance, parent_label=parent_label
)
if db_label:
Expand All @@ -414,20 +429,6 @@ def update_labels(cls,
f'sublabels:{sublabels}, parent_label:{parent_label}'
)

if not label.get('deleted'):
cls.update_labels(sublabels, parent_instance=parent_instance, parent_label=db_label)

if label.get('id') is None and db_label.type == str(models.LabelType.SKELETON):
for db_sublabel in list(db_label.sublabels.all()):
svg = svg.replace(
f'data-label-name="{db_sublabel.name}"',
f'data-label-id="{db_sublabel.id}"'
)
db_skeleton = models.Skeleton.objects.create(root=db_label, svg=svg)
logger.info(
f'label:update Skeleton id:{db_skeleton.id} for label_id:{db_label.id}'
)

@classmethod
def _get_parent_info(cls, parent_instance: Union[models.Project, models.Task]):
parent_info = {}
Expand Down
22 changes: 22 additions & 0 deletions tests/python/rest_api/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,28 @@ def test_project_staff_org_members_can_add_label(
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == project["labels"]["count"] + 1

def test_admin_can_add_skeleton(self, projects, admin_user):
project = list(projects)[0]
new_skeleton = {
"name": "skeleton1",
"type": "skeleton",
"sublabels": [
{
"name": "1",
"type": "points",
}
],
"svg": '<circle r="1.5" stroke="black" fill="#b3b3b3" cx="48.794559478759766" '
'cy="36.98698806762695" stroke-width="0.1" data-type="element node" '
'data-element-id="1" data-node-id="1" data-label-name="597501"></circle>',
}

response = patch_method(
admin_user, f'/projects/{project["id"]}', {"labels": [new_skeleton]}
)
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == project["labels"]["count"] + 1


@pytest.mark.usefixtures("restore_db_per_class")
class TestGetProjectPreview:
Expand Down
20 changes: 20 additions & 0 deletions tests/python/rest_api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,26 @@ def test_task_staff_org_members_can_add_label(
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == task["labels"]["count"] + 1

def test_admin_can_add_skeleton(self, tasks, admin_user):
task = list(tasks)[0]
new_skeleton = {
"name": "skeleton1",
"type": "skeleton",
"sublabels": [
{
"name": "1",
"type": "points",
}
],
"svg": '<circle r="1.5" stroke="black" fill="#b3b3b3" cx="48.794559478759766" '
'cy="36.98698806762695" stroke-width="0.1" data-type="element node" '
'data-element-id="1" data-node-id="1" data-label-name="597501"></circle>',
}

response = patch_method(admin_user, f'/tasks/{task["id"]}', {"labels": [new_skeleton]})
assert response.status_code == HTTPStatus.OK
assert response.json()["labels"]["count"] == task["labels"]["count"] + 1


@pytest.mark.usefixtures("restore_db_per_function")
@pytest.mark.usefixtures("restore_cvat_data")
Expand Down

0 comments on commit 039bfe1

Please sign in to comment.