42 changes: 42 additions & 0 deletions .github/workflows/release.yml
@@ -0,0 +1,42 @@
name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI

on:
  release:
    types: [prereleased, released]

jobs:
  build-n-publish:
    name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.8"
      - name: Upgrade pip
        run: >-
          python -m
          pip install
          pip --upgrade
          --user
      - name: Install pypa/build
        run: >-
          python -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: >-
          python -m
          build
          --sdist
          --wheel
          --outdir dist/
          .
      - name: Publish distribution 📦 to PyPI
        if: startsWith(github.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.pypi_password }}
          verbose: true
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,7 +2,7 @@

name = "superannotate_databricks_connector" # Required

version = "0.0.1dev1"
version = "0.0.2dev1"

description = "Custom functions to work with SuperAnnotate in Databricks"

8 changes: 8 additions & 0 deletions readme.md
@@ -34,6 +34,14 @@ If you are running the tests for the first time, you first have to build the base docker image:
```bash
docker build -f Dockerfile.spark -t spark_docker_base .
```

### Build package

In the main directory, run the following to build the package; a source tarball and a `.whl` file are written to `dist/`.

```bash
python -m build
```

### Usage
First, import the required function:

1 change: 0 additions & 1 deletion src/superannotate_databricks_connector/schemas/text_schema.py
@@ -42,7 +42,6 @@ def get_text_schema():
schema = StructType([
StructField("name", StringType(), True),
StructField("url", StringType(), True),
StructField("contentLength", IntegerType(), True),
StructField("projectId", IntegerType(), True),
StructField("status", StringType(), True),
StructField("annotatorEmail", StringType(), True),
20 changes: 19 additions & 1 deletion src/superannotate_databricks_connector/schemas/vector_schema.py
@@ -59,6 +59,23 @@ def get_vector_instance_schema():
return instance_schema


def get_vector_tag_schema():
schema = StructType([
StructField("instance_type", StringType(), True),
StructField("classId", IntegerType(), True),
StructField("probability", IntegerType(), True),
StructField("attributes", ArrayType(MapType(StringType(),
StringType())),
True),
StructField("createdAt", StringType(), True),
StructField("createdBy", MapType(StringType(), StringType()), True),
StructField("creationType", StringType(), True),
StructField("updatedAt", StringType(), True),
StructField("updatedBy", MapType(StringType(), StringType()), True),
StructField("className", StringType(), True)])
return schema


def get_vector_schema():
schema = StructType([
StructField("image_height", IntegerType(), True),
@@ -73,6 +90,7 @@ def get_vector_schema():
StructField("instances", ArrayType(get_vector_instance_schema()),
True),
StructField("bounding_boxes", ArrayType(IntegerType()), True),
StructField("comments", ArrayType(get_comment_schema()), True)
StructField("comments", ArrayType(get_comment_schema()), True),
StructField("tags", ArrayType(get_vector_tag_schema()), True)
])
return schema
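The new `get_vector_tag_schema` matches the shape of tag instances in SuperAnnotate's vector export. As a minimal sketch of using the schema on its own, with a hypothetical tag record (every field value below is an illustrative assumption, not data from this PR):

```python
# Minimal sketch: load one hypothetical tag instance with the new schema.
# All field values are illustrative assumptions, not data from this PR.
from pyspark.sql import SparkSession
from superannotate_databricks_connector.schemas.vector_schema import (
    get_vector_tag_schema
)

spark = SparkSession.builder.master("local").getOrCreate()

sample_tag = {
    "instance_type": "tag",
    "classId": 10230,
    "probability": 100,
    "attributes": [{"name": "split", "value": "train"}],
    "createdAt": "2023-01-01T00:00:00.000Z",
    "createdBy": {"email": "annotator@example.com", "role": "Annotator"},
    "creationType": "Manual",
    "updatedAt": "2023-01-01T00:00:00.000Z",
    "updatedBy": {"email": "annotator@example.com", "role": "Annotator"},
    "className": "vehicle",
}

df = spark.createDataFrame([sample_tag], schema=get_vector_tag_schema())
df.show(truncate=False)
```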
5 changes: 3 additions & 2 deletions src/superannotate_databricks_connector/text.py
@@ -1,5 +1,7 @@
from datetime import datetime
-from superannotate_databricks_connector.schemas.text_schema import get_text_schema
+from superannotate_databricks_connector.schemas.text_schema import (
+    get_text_schema
+)


def convert_dates(instance):
@@ -40,7 +42,6 @@ def get_text_dataframe(annotations, spark):
flattened_item = {
"name": item["metadata"]["name"],
"url": item["metadata"]["url"],
"contentLength": item["metadata"]["contentLength"],
"projecId": item["metadata"]["projectId"],
"status": item["metadata"]["status"],
"annotatorEmail": item["metadata"]["annotatorEmail"],
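The `contentLength` removal here mirrors the schema change above: the field is dropped from both the flattened record and `get_text_schema`. A hypothetical metadata dict showing the lookup that went away (assuming, though the PR does not say so explicitly, that some exports omit the field):

```python
# Hypothetical export metadata; "contentLength" is deliberately absent.
# This dict is illustrative only and not taken from the PR's test data.
metadata = {
    "name": "document_1",
    "url": "https://example.com/document_1.txt",
    "projectId": 1234,
    "status": "Completed",
    "annotatorEmail": "annotator@example.com",
}

# The removed line performed a direct lookup, which would raise KeyError
# on input like the above; the field is now neither read nor stored.
assert "contentLength" not in metadata
assert metadata.get("contentLength") is None
```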
6 changes: 4 additions & 2 deletions src/superannotate_databricks_connector/vector.py
@@ -1,4 +1,6 @@
-from superannotate_databricks_connector.schemas.vector_schema import get_vector_schema
+from superannotate_databricks_connector.schemas.vector_schema import (
+    get_vector_schema
+)


def process_comment(comment):
@@ -140,7 +142,7 @@ def get_vector_dataframe(annotations, spark, custom_id_map=None):
'qaEmail': item["metadata"]['qaEmail'],
"instances": [process_vector_object(instance, custom_id_map)
for instance in item["instances"]
if instance["type"] == "object"],
if instance["type"] != "tag"],
"bounding_boxes": get_boxes(item["instances"], custom_id_map),
"tags": [process_vector_tag(instance, custom_id_map)
for instance in item["instances"]
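This filter change is the functional core of the file's diff: previously only instances with `type == "object"` reached the `instances` column; now every non-tag instance does, while tag instances feed the new `tags` column. A condensed sketch of that routing with hypothetical instance dicts (only the `type` key matters; `process_vector_object` and `process_vector_tag` are elided):

```python
# Condensed sketch of the instance routing implied by the diff above.
# The instance dicts are hypothetical; only "type" drives the split.
item_instances = [
    {"type": "bbox", "classId": 1},    # now kept; the old == "object"
                                       # filter would have dropped it
    {"type": "object", "classId": 2},  # kept before and after
    {"type": "tag", "classId": 3},     # routed to the "tags" column only
]

non_tag_instances = [i for i in item_instances if i["type"] != "tag"]
tag_instances = [i for i in item_instances if i["type"] == "tag"]

assert [i["classId"] for i in non_tag_instances] == [1, 2]
assert [i["classId"] for i in tag_instances] == [3]
```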
9 binary files not shown.
14 changes: 9 additions & 5 deletions tests/test_vector.py
@@ -17,11 +17,13 @@
class TestVectorInstances(unittest.TestCase):
def __init__(self, *args):
super().__init__(*args)
with open(os.path.join(DATA_SET_PATH, "vector/example_annotation.json"), "r") as f:
with open(os.path.join(DATA_SET_PATH,
"vector/example_annotation.json"), "r") as f:
data = json.load(f)

target_data = []
-        with open(os.path.join(DATA_SET_PATH, 'vector/expected_instances.json'),"r") as f:
+        with open(os.path.join(DATA_SET_PATH,
+                               'vector/expected_instances.json'), "r") as f:
for line in f:
target_data.append(json.loads(line))

@@ -96,20 +98,22 @@ def test_get_boxes(self):
"y1": 2.1,
"y2": 18.9
},
"classId": 10229}]
"classId": 10229}]
target = [2, 1, 13, 22, 10228, 3, 2, 4, 19, 10229]
self.assertEqual(get_boxes(instances), target)


class TestVectorDataFrame(unittest.TestCase):
def test_vector_dataframe(self):
spark = SparkSession.builder.master("local").getOrCreate()
with open(os.path.join(DATA_SET_PATH, "vector/example_annotation.json"),"r") as f:
with open(os.path.join(DATA_SET_PATH,
"vector/example_annotation.json"), "r") as f:
data = json.load(f)

actual_df = get_vector_dataframe([data], spark)

-        expected_df = spark.read.parquet(os.path.join(DATA_SET_PATH, "vector/expected_df.parquet"))
+        expected_df = spark.read.parquet(os.path.join(
+            DATA_SET_PATH, "vector/expected_df.parquet"))
self.assertEqual(sorted(actual_df.collect()),
sorted(expected_df.collect()))

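The `target` list in `test_get_boxes` reads as five values per box: `[x1, y1, x2, y2, classId]`, with coordinates rounded to integers (2.1 becomes 2, 18.9 becomes 19). A sketch of that behaviour as implied by the test; this is inferred from the expected output, not copied from the repo, and it omits the `custom_id_map` argument that `vector.py` also passes:

```python
# Sketch of get_boxes as implied by test_get_boxes: flatten each instance
# into [x1, y1, x2, y2, classId], rounding coordinates to integers.
# Inferred from the expected output, not copied from the repository.
def get_boxes_sketch(instances):
    flat = []
    for instance in instances:
        points = instance["points"]
        flat.extend([
            round(points["x1"]),
            round(points["y1"]),
            round(points["x2"]),
            round(points["y2"]),
            instance["classId"],
        ])
    return flat

example = [{"points": {"x1": 3.2, "y1": 2.1, "x2": 4.4, "y2": 18.9},
            "classId": 10229}]
assert get_boxes_sketch(example) == [3, 2, 4, 19, 10229]
```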