From f03b2bac4b4b6c49967769e1f366271763449c18 Mon Sep 17 00:00:00 2001 From: Vaghinak Basentsyan Date: Fri, 16 May 2025 17:32:11 +0400 Subject: [PATCH] Fix connections validaiton --- src/superannotate/__init__.py | 2 +- src/superannotate/lib/core/usecases/annotations.py | 5 ++--- src/superannotate/lib/core/usecases/projects.py | 11 ++++++++--- tests/integration/steps/test_steps.py | 10 ++-------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/superannotate/__init__.py b/src/superannotate/__init__.py index 79f6f695..ae8804fe 100644 --- a/src/superannotate/__init__.py +++ b/src/superannotate/__init__.py @@ -3,7 +3,7 @@ import sys -__version__ = "4.4.36dev1" +__version__ = "4.4.35dev2" os.environ.update({"sa_version": __version__}) diff --git a/src/superannotate/lib/core/usecases/annotations.py b/src/superannotate/lib/core/usecases/annotations.py index 73afab79..950c510a 100644 --- a/src/superannotate/lib/core/usecases/annotations.py +++ b/src/superannotate/lib/core/usecases/annotations.py @@ -2101,9 +2101,8 @@ def execute(self): if categorization_enabled: item_id_category_map = {} for item_name in uploaded_annotations: - category = ( - name_annotation_map[item_name]["metadata"] - .get("item_category", None) + category = name_annotation_map[item_name]["metadata"].get( + "item_category", None ) if category: item_id_category_map[name_item_map[item_name].id] = category diff --git a/src/superannotate/lib/core/usecases/projects.py b/src/superannotate/lib/core/usecases/projects.py index 3eafee61..4455d242 100644 --- a/src/superannotate/lib/core/usecases/projects.py +++ b/src/superannotate/lib/core/usecases/projects.py @@ -1,5 +1,6 @@ import decimal import logging +import math from collections import defaultdict from typing import List @@ -608,10 +609,14 @@ def validate_project_type(self): def validate_connections(self): if not self._connections: return - - if len(self._connections) > len(self._steps): + if not all([len(i) == 2 for i in self._connections]): + raise AppException("Invalid connections.") + steps_count = len(self._steps) + if len(self._connections) > max( + math.factorial(steps_count) / (2 * math.factorial(steps_count - 2)), 1 + ): raise AppValidationException( - "Invalid connections: more connections than steps." + "Invalid connections: duplicates in a connection group." ) possible_connections = set(range(1, len(self._steps) + 1)) diff --git a/tests/integration/steps/test_steps.py b/tests/integration/steps/test_steps.py index 27e38f40..6dbddf6d 100644 --- a/tests/integration/steps/test_steps.py +++ b/tests/integration/steps/test_steps.py @@ -1,12 +1,5 @@ -import json -import os -import tempfile -from pathlib import Path - -from numpy.ma.core import arange from src.superannotate import AppException from src.superannotate import SAClient -from tests import DATA_SET_PATH from tests.integration.base import BaseTestCase sa = SAClient() @@ -236,7 +229,8 @@ def test_create_invalid_connection(self): sa.set_project_steps( *args, connections=[ - [1, 2, 1], + [1, 2], + [2, 1], ] ) with self.assertRaisesRegexp(