Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/superannotate/lib/app/interface/sdk_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,6 @@ def upload_annotations_from_folder_to_project(
"""

project_name, folder_name = extract_project_folder(project)
project = controller.get_project_metadata(project_name).data

if recursive_subfolders:
logger.info(
Expand Down Expand Up @@ -2478,7 +2477,7 @@ def upload_preannotations_from_folder_to_project(
folder_name=folder_name,
annotation_paths=annotation_paths, # noqa: E203
client_s3_bucket=from_s3_bucket,
is_pre_annotations=True
is_pre_annotations=True,
)
if response.errors:
raise AppException(response.errors)
Expand Down
4 changes: 3 additions & 1 deletion src/superannotate/lib/app/interface/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def validate(cls, value: Union[str]) -> Union[str]:
if cls.curtail_length and len(value) > cls.curtail_length:
value = value[: cls.curtail_length]
if value.lower() not in AnnotationStatus.values():
raise TypeError(f"Available statuses is {', '.join(AnnotationStatus.titles())}. ")
raise TypeError(
f"Available statuses is {', '.join(AnnotationStatus.titles())}. "
)
return value


Expand Down
83 changes: 46 additions & 37 deletions src/superannotate/lib/core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def map_annotation_classes_name(annotation_classes, reporter: Reporter) -> dict:


def fill_annotation_ids(
annotations: dict,
annotation_classes_name_maps: dict,
templates: List[dict],
reporter: Reporter):
annotations: dict,
annotation_classes_name_maps: dict,
templates: List[dict],
reporter: Reporter,
):
annotation_classes_name_maps = annotation_classes_name_maps
if "instances" not in annotations:
return
Expand All @@ -67,7 +68,7 @@ def fill_annotation_ids(
annotation_classes_name_maps.update(unknown_classes)
template_name_id_map = {template["name"]: template["id"] for template in templates}
for annotation in (
i for i in annotations["instances"] if i.get("type", None) == "template"
i for i in annotations["instances"] if i.get("type", None) == "template"
):
annotation["templateId"] = template_name_id_map.get(
annotation.get("templateName", ""), -1
Expand All @@ -76,25 +77,35 @@ def fill_annotation_ids(
for annotation in [i for i in annotations["instances"] if "className" in i]:
annotation_class_name = annotation["className"]
if annotation_class_name not in annotation_classes_name_maps.keys():
reporter.log_warning(f"Couldn't find annotation class {annotation_class_name}")
reporter.log_warning(
f"Couldn't find annotation class {annotation_class_name}"
)
continue
annotation["classId"] = annotation_classes_name_maps[annotation_class_name]["id"]
annotation["classId"] = annotation_classes_name_maps[annotation_class_name][
"id"
]
for attribute in annotation["attributes"]:
if (
attribute["groupName"]
not in annotation_classes_name_maps[annotation_class_name]["attribute_groups"]
attribute["groupName"]
not in annotation_classes_name_maps[annotation_class_name][
"attribute_groups"
]
):
reporter.log_warning(f"Couldn't find annotation group {attribute['groupName']}.")
reporter.store_message("Couldn't find annotation groups", attribute["groupName"])
reporter.log_warning(
f"Couldn't find annotation group {attribute['groupName']}."
)
reporter.store_message(
"Couldn't find annotation groups", attribute["groupName"]
)
continue
attribute["groupId"] = annotation_classes_name_maps[annotation_class_name][
"attribute_groups"
][attribute["groupName"]]["id"]
if (
attribute["name"]
not in annotation_classes_name_maps[annotation_class_name][
"attribute_groups"
][attribute["groupName"]]["attributes"]
attribute["name"]
not in annotation_classes_name_maps[annotation_class_name][
"attribute_groups"
][attribute["groupName"]]["attributes"]
):
del attribute["groupId"]
reporter.log_warning(
Expand Down Expand Up @@ -145,9 +156,7 @@ def convert_timestamp(timestamp):
end_time = safe_time(convert_timestamp(parameter["end"]))

for timestamp_data in parameter["timestamps"]:
timestamp = safe_time(
convert_timestamp(timestamp_data["timestamp"])
)
timestamp = safe_time(convert_timestamp(timestamp_data["timestamp"]))
editor_instance["timeline"][timestamp] = {}

if timestamp == start_time:
Expand All @@ -157,9 +166,9 @@ def convert_timestamp(timestamp):
editor_instance["timeline"][timestamp]["active"] = False

if timestamp_data.get("points", None):
editor_instance["timeline"][timestamp][
editor_instance["timeline"][timestamp]["points"] = timestamp_data[
"points"
] = timestamp_data["points"]
]

if not class_name_mapper.get(meta["className"], None):
continue
Expand All @@ -169,41 +178,41 @@ def convert_timestamp(timestamp):
key = attribute["groupName"], attribute["name"]
existing_attributes_in_current_instance.add(key)
attributes_to_add = (
existing_attributes_in_current_instance - active_attributes
existing_attributes_in_current_instance - active_attributes
)
attributes_to_delete = (
active_attributes - existing_attributes_in_current_instance
active_attributes - existing_attributes_in_current_instance
)
if attributes_to_add or attributes_to_delete:
editor_instance["timeline"][timestamp][
"attributes"
] = defaultdict(list)
editor_instance["timeline"][timestamp]["attributes"] = defaultdict(
list
)
for new_attribute in attributes_to_add:
attr = {
"id": class_name_mapper[class_name]["attribute_groups"][
new_attribute[0]
]["attributes"][new_attribute[1]],
"groupId": class_name_mapper[class_name][
"attribute_groups"
][new_attribute[0]]["id"],
"groupId": class_name_mapper[class_name]["attribute_groups"][
new_attribute[0]
]["id"],
}
active_attributes.add(new_attribute)
editor_instance["timeline"][timestamp]["attributes"][
"+"
].append(attr)
editor_instance["timeline"][timestamp]["attributes"]["+"].append(
attr
)
for attribute_to_delete in attributes_to_delete:
attr = {
"id": class_name_mapper[class_name]["attribute_groups"][
attribute_to_delete[0]
]["attributes"][attribute_to_delete[1]],
"groupId": class_name_mapper[class_name][
"attribute_groups"
][attribute_to_delete[0]]["id"],
"groupId": class_name_mapper[class_name]["attribute_groups"][
attribute_to_delete[0]
]["id"],
}
active_attributes.remove(attribute_to_delete)
editor_instance["timeline"][timestamp]["attributes"][
"-"
].append(attr)
editor_instance["timeline"][timestamp]["attributes"]["-"].append(
attr
)

editor_data["instances"].append(editor_instance)
return editor_data
Expand Down
20 changes: 13 additions & 7 deletions src/superannotate/lib/core/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

class Reporter:
def __init__(
self,
log_info: bool = True,
log_warning: bool = True,
disable_progress_bar: bool = False
self,
log_info: bool = True,
log_warning: bool = True,
disable_progress_bar: bool = False,
):
self.logger = logging.getLogger("root")
self._log_info = log_info
Expand All @@ -31,11 +31,17 @@ def log_warning(self, value: str):
self.logger.warning(value)
self.warning_messages.append(value)

def start_progress(self, iterations: Union[int, range], description: str = "Processing"):
def start_progress(
self, iterations: Union[int, range], description: str = "Processing"
):
if isinstance(iterations, range):
self.progress_bar = tqdm.tqdm(iterations, desc=description, disable=self._disable_progress_bar)
self.progress_bar = tqdm.tqdm(
iterations, desc=description, disable=self._disable_progress_bar
)
else:
self.progress_bar = tqdm.tqdm(total=iterations, desc=description, disable=self._disable_progress_bar)
self.progress_bar = tqdm.tqdm(
total=iterations, desc=description, disable=self._disable_progress_bar
)

def finish_progress(self):
self.progress_bar.close()
Expand Down
13 changes: 11 additions & 2 deletions src/superannotate/lib/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class AnnotationType(StrictStr):
@classmethod
def validate(cls, value: str) -> Union[str]:
if value not in ANNOTATION_TYPES.keys():
raise ValidationError([ErrorWrapper(TypeError(f"invalid value {value}"), "type")], cls)
raise ValidationError(
[ErrorWrapper(TypeError(f"invalid value {value}"), "type")], cls
)
return value


Expand Down Expand Up @@ -156,7 +158,9 @@ class VectorAnnotation(BaseModel):
def check_instances(cls, instance):
annotation_type = AnnotationType.validate(instance.get("type"))
if not annotation_type:
raise ValidationError([ErrorWrapper(TypeError("value not specified"), "type")], cls)
raise ValidationError(
[ErrorWrapper(TypeError("value not specified"), "type")], cls
)
result = validate_model(ANNOTATION_TYPES[annotation_type], instance)
if result[2]:
raise ValidationError(
Expand Down Expand Up @@ -219,3 +223,8 @@ class VideoAnnotation(BaseModel):
metadata: VideoMetaData
instances: List[VideoInstance]
tags: List[str]


class DocumentAnnotation(BaseModel):
instances: list
tags: List[str]
Loading