Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions src/superannotate_databricks_connector/schemas/comment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
FloatType,
BooleanType,
MapType,
ArrayType
ArrayType,
IntegerType
)

from .shapes import get_bbox_schema

def get_comment_schema():
comment_schema = StructType([

def get_vector_comment_schema():
return StructType([
StructField("correspondence",
ArrayType(MapType(
StringType(),
Expand All @@ -23,12 +26,53 @@ def get_comment_schema():
StructField("createdBy", MapType(
StringType(),
StringType()),
True),
StructField("creationType", StringType(), True),
StructField("updatedAt", StringType(), True),
StructField("updatedBy", MapType(
StringType(),
StringType()),
True)
])


def get_video_timestamp_schema():
return StructType([
StructField("timestamp", IntegerType(), True),
StructField("points", get_bbox_schema(), True)
])


def get_video_comment_parameter_schema():
return StructType([
StructField("start", IntegerType(), True),
StructField("end", IntegerType, True),
StructField("timestamps", ArrayType(
get_video_timestamp_schema()), True)
])


def get_video_comment_schema():
return StructType([
StructField("correspondence",
ArrayType(MapType(
StringType(),
StringType())),
True),
StructField("start", IntegerType(), True),
StructField("end", IntegerType(), True),
StructField("createdAt", StringType(), True),
StructField("createdBy", MapType(
StringType(),
StringType()),
True),
StructField("creationType", StringType(), True),
StructField("updatedAt", StringType(), True),
StructField("updatedBy", MapType(
StringType(),
StringType()),
True)
True),
StructField("parameters",
ArrayType(get_video_comment_parameter_schema()), True)

])
return comment_schema
131 changes: 131 additions & 0 deletions src/superannotate_databricks_connector/schemas/shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from pyspark.sql.types import (
StructType,
StructField,
FloatType,
ArrayType
)


def get_bbox_schema():
"""
Defines the schema of a bounding box

Args:
None

Returns:
StructType: Schema of a bbox
"""
return StructType([
StructField("x1", FloatType(), True),
StructField("y1", FloatType(), True),
StructField("x2", FloatType(), True),
StructField("y2", FloatType(), True)
])


def get_rbbox_schema():
"""
Defines the schema of a rotated bounding box
this contains one point for each corned

Args:
None

Returns:
StructType: Schema of a bbox
"""
return StructType([
StructField("x1", FloatType(), True),
StructField("y1", FloatType(), True),
StructField("x2", FloatType(), True),
StructField("y2", FloatType(), True),
StructField("x3", FloatType(), True),
StructField("y3", FloatType(), True),
StructField("x4", FloatType(), True),
StructField("y5", FloatType(), True)
])


def get_point_schema():
"""
Defines the schema of a point

Args:
None

Returns:
StructType: Schema of a point
"""
return StructType([
StructField("x", FloatType(), True),
StructField("y", FloatType(), True)
])


def get_cuboid_schema():
"""
Defines the schema of a cuboid (3d bounding box)

Args:
None

Returns:
StructType: Schema of a cuboid
"""
return StructType([
StructField("f1", get_point_schema(), True),
StructField("f2", get_point_schema(), True),
StructField("r1", get_point_schema(), True),
StructField("r2", get_point_schema(), True)
])


def get_ellipse_schema():
"""
Defines the schema of an ellipse

Args:
None

Returns:
StructType: Schema of an ellipse
"""
return StructType([
StructField("cx", FloatType(), True),
StructField("cy", FloatType(), True),
StructField("rx", FloatType(), True),
StructField("ty", FloatType(), True),
StructField("angle", FloatType(), True)
])


def get_polygon_schema():
"""
Defines the schema of a polygon. It contains a shell as well
as excluded points

Args:
None

Returns:
StructType: Schema of a polygon with holes
"""
return StructType([
StructField("points", ArrayType(FloatType()), True),
StructField("exclude", ArrayType(ArrayType(FloatType())), True)
])


def get_polyline_schema():
"""
Defines the schema of a polyline
A simple array of float

Args:
None

Returns:
ArrayType: Schema of a polygon with holes
"""
return ArrayType(FloatType())
25 changes: 25 additions & 0 deletions src/superannotate_databricks_connector/schemas/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pyspark.sql.types import (
StructType,
StructField,
StringType,
IntegerType,
MapType,
ArrayType,
)


def get_tag_schema():
schema = StructType([
StructField("instance_type", StringType(), True),
StructField("classId", IntegerType(), True),
StructField("probability", IntegerType(), True),
StructField("attributes", ArrayType(MapType(StringType(),
StringType())),
True),
StructField("createdAt", StringType(), True),
StructField("createdBy", MapType(StringType(), StringType()), True),
StructField("creationType", StringType(), True),
StructField("updatedAt", StringType(), True),
StructField("updatedBy", MapType(StringType(), StringType()), True),
StructField("className", StringType(), True)])
return schema
70 changes: 21 additions & 49 deletions src/superannotate_databricks_connector/schemas/vector_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,35 @@
StructField,
StringType,
IntegerType,
FloatType,
BooleanType,
MapType,
ArrayType
)
from .comment import get_comment_schema


def get_point_schema():
point_schema = StructType([
StructField("x", FloatType(), True),
StructField("y", FloatType(), True)
])
return point_schema


def get_cuboid_schema():
cuboid_points_schema = StructType([
StructField("f1", get_point_schema(), True),
StructField("f2", get_point_schema(), True),
StructField("r1", get_point_schema(), True),
StructField("r2", get_point_schema(), True)
])
return cuboid_points_schema
from .comment import get_vector_comment_schema
from .shapes import (
get_point_schema,
get_cuboid_schema,
get_bbox_schema,
get_ellipse_schema,
get_polygon_schema,
get_polyline_schema,
get_rbbox_schema
)
from .tag import get_tag_schema


def get_vector_instance_schema():
instance_schema = StructType([
return StructType([
StructField("instance_type", StringType(), True),
StructField("classId", IntegerType(), True),
StructField("probability", IntegerType(), True),
StructField("bbox_points", MapType(StringType(), FloatType()), True),
StructField("polygon_points", ArrayType(FloatType()), True),
StructField("polygon_exclude", ArrayType(ArrayType(FloatType())),
True),
StructField("cuboid_points", get_cuboid_schema(), True),
StructField("ellipse_points", MapType(StringType(), FloatType()),
True),
StructField("point_points", MapType(StringType(), FloatType()), True),
StructField("bbox", get_bbox_schema(), True),
StructField("rbbox", get_rbbox_schema(), True),
StructField("polygon", get_polygon_schema()),
StructField("cuboid", get_cuboid_schema(), True),
StructField("ellipse", get_ellipse_schema(), True),
StructField("polyline", get_polyline_schema(), True),
StructField("point", get_point_schema(), True),
StructField("groupId", IntegerType(), True),
StructField("locked", BooleanType(), True),
StructField("attributes", ArrayType(MapType(StringType(),
Expand All @@ -56,24 +46,6 @@ def get_vector_instance_schema():
StructField("updatedBy", MapType(StringType(), StringType()), True),
StructField("className", StringType(), True)
])
return instance_schema


def get_vector_tag_schema():
schema = StructType([
StructField("instance_type", StringType(), True),
StructField("classId", IntegerType(), True),
StructField("probability", IntegerType(), True),
StructField("attributes", ArrayType(MapType(StringType(),
StringType())),
True),
StructField("createdAt", StringType(), True),
StructField("createdBy", MapType(StringType(), StringType()), True),
StructField("creationType", StringType(), True),
StructField("updatedAt", StringType(), True),
StructField("updatedBy", MapType(StringType(), StringType()), True),
StructField("className", StringType(), True)])
return schema


def get_vector_schema():
Expand All @@ -90,7 +62,7 @@ def get_vector_schema():
StructField("instances", ArrayType(get_vector_instance_schema()),
True),
StructField("bounding_boxes", ArrayType(IntegerType()), True),
StructField("comments", ArrayType(get_comment_schema()), True),
StructField("tags", ArrayType(get_vector_tag_schema()), True)
StructField("comments", ArrayType(get_vector_comment_schema()), True),
StructField("tags", ArrayType(get_tag_schema()), True)
])
return schema
Loading