In [None]:
from __future__ import annotations

# Data processing status example

This example assumes you have started the OSS server with the example dataset located in the test
assets directory. From the root of the rerun repository, you can start the server with the following command:

```shell
rerun server --dataset ./tests/assets/rrd/dataset
```

In [None]:
import tempfile
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import pyarrow as pa
from datafusion import col, functions as F
from rerun.catalog import CatalogClient, DatasetEntry

if TYPE_CHECKING:
    from collections.abc import Generator

    from datafusion import DataFrame

# Address handed to `CatalogClient` below; presumably where the local `rerun server` listens — confirm the port.
CATALOG_URL = "rerun+http://localhost:51234"
# Name of the dataset entry fetched with `client.get_dataset`.
DATASET_NAME = "dataset"

# Catalog table recording, per partition, whether processing has completed.
STATUS_TABLE_NAME = "status"
# Catalog table receiving the per-partition aggregates written by `process_partitions`.
RESULTS_TABLE_NAME = "results"

In [None]:
def create_status_table(client: CatalogClient, directory: Path) -> DataFrame:
    """
    Return the status table, registering it with the catalog first if needed.

    When the catalog has no table named `STATUS_TABLE_NAME`, one is created
    with the columns `rerun_partition_id` (string), `is_complete` (bool) and
    `update_time` (millisecond timestamp), stored at a `file://` URL inside
    `directory`. In both cases the table is then fetched from the catalog.
    """
    if STATUS_TABLE_NAME not in client.table_names():
        status_schema = pa.schema([
            ("rerun_partition_id", pa.utf8()),
            ("is_complete", pa.bool_()),
            ("update_time", pa.timestamp(unit="ms")),
        ])
        client.create_table(
            STATUS_TABLE_NAME,
            status_schema,
            f"file://{directory}/{STATUS_TABLE_NAME}",
        )

    return client.get_table(name=STATUS_TABLE_NAME)

In [None]:
def create_results_table(client: CatalogClient, directory: Path) -> DataFrame:
    """
    Return the results table, registering it with the catalog first if needed.

    When the catalog has no table named `RESULTS_TABLE_NAME`, one is created
    at a `file://` URL inside `directory`. Its columns are the partition id,
    the first and last log time (nanosecond timestamps) and one 3-element
    float32 list per tracked object. In both cases the table is then fetched
    from the catalog.
    """
    if RESULTS_TABLE_NAME not in client.table_names():
        log_time_type = pa.timestamp(unit="ns")
        position_type = pa.list_(pa.float32(), 3)

        results_schema = pa.schema([
            ("rerun_partition_id", pa.utf8()),
            ("first_log_time", log_time_type),
            ("last_log_time", log_time_type),
            ("first_position_obj1", position_type),
            ("first_position_obj2", position_type),
            ("first_position_obj3", position_type),
        ])
        client.create_table(
            RESULTS_TABLE_NAME,
            results_schema,
            f"file://{directory}/{RESULTS_TABLE_NAME}",
        )

    return client.get_table(name=RESULTS_TABLE_NAME)

In [None]:
def find_missing_partitions(
    partition_table: DataFrame,
    status_table: DataFrame
) -> list[str]:
    """
    List the partition ids that have no completed row in the status table.

    Parameters
    ----------
    partition_table:
        DataFrame containing a `rerun_partition_id` column for the known partitions.
    status_table:
        DataFrame containing `rerun_partition_id` and `is_complete` columns.

    Returns
    -------
    list[str]
        The ids from `partition_table` that do not appear in any status row
        with `is_complete` set, as plain Python strings.

    """
    # `== True` is deliberate: it builds a DataFusion filter expression rather
    # than evaluating Python truthiness, so it must not be rewritten as `is True`.
    completed = status_table.filter(col("is_complete") == True)  # noqa: E712

    # An anti join keeps only the partition rows that have no completed match.
    batches = partition_table.join(completed, on="rerun_partition_id", how="anti").collect()

    # Read just the id column and convert to Python values, so the declared
    # `list[str]` holds even if `partition_table` carries extra columns.
    return [
        partition_id
        for batch in batches
        for partition_id in batch.column("rerun_partition_id").to_pylist()
    ]

In [None]:
def process_partitions(client: CatalogClient, dataset: DatasetEntry, partition_list: list[pa.Scalar | str]) -> None:
    """
    Compute per-partition results for a batch of partitions and record progress.

    The steps run in this order:

    1. Append a status row per partition with `is_complete=False`.
    2. Query `dataset` on the `time_1` index for those partitions and aggregate,
       per partition, the first/last `log_time` and the first non-null position
       of `/obj1`, `/obj2` and `/obj3`.
    3. Write the aggregate to the results table.
    4. Append a status row per partition with `is_complete=True`.

    Parameters
    ----------
    client:
        Catalog client used to append rows to the status table.
    dataset:
        Dataset entry to query.
    partition_list:
        Partition ids to process, either as strings or as Arrow scalars.
        An empty list is a no-op.

    """
    if not partition_list:
        # Nothing to do; also avoids issuing a query with no partition filter.
        return

    # Normalize once up front so both status appends and the query filter see
    # the same string representation of each id.
    partition_ids = [str(p) for p in partition_list]

    client.append_to_table(
        STATUS_TABLE_NAME,
        rerun_partition_id=partition_ids,
        is_complete=[False] * len(partition_ids),
        update_time=[datetime.now()] * len(partition_ids),
    )

    view = dataset.dataframe_query_view(index="time_1", contents="/**")
    df = view.filter_partition_id(*partition_ids).df()

    def first_position(entity: str):
        """Build the aggregate selecting the earliest non-null position of `entity`."""
        positions = col(f"/{entity}:Points3D:positions")
        return F.first_value(
            positions[0],
            filter=positions.is_not_null(),
            order_by=col("time_1"),
        ).alias(f"first_position_{entity}")

    results = df.aggregate(
        "rerun_partition_id",
        [
            F.min(col("log_time")).alias("first_log_time"),
            F.max(col("log_time")).alias("last_log_time"),
            *(first_position(entity) for entity in ("obj1", "obj2", "obj3")),
        ],
    )
    results.write_table(RESULTS_TABLE_NAME)

    # The `True` rows are what `find_missing_partitions` looks for, so these
    # partitions are not picked up again.
    client.append_to_table(
        STATUS_TABLE_NAME,
        rerun_partition_id=partition_ids,
        is_complete=[True] * len(partition_ids),
        update_time=[datetime.now()] * len(partition_ids),
    )

In [None]:

with tempfile.TemporaryDirectory() as temp_dir:
    # Both tables are registered at file:// URLs under this scratch directory,
    # which is removed when the block exits. Everything that reads them,
    # including the final display calls, therefore stays inside the `with`.
    temp_path = Path(temp_dir)

    client = CatalogClient(CATALOG_URL)
    dataset = client.get_dataset(name=DATASET_NAME)

    status_table = create_status_table(client, temp_path)
    results_table = create_results_table(client, temp_path)

    # TODO(tsaucer) replace with partition table query
    partition_table = (
        dataset.dataframe_query_view(index="time_1", contents="/**").df().select("rerun_partition_id").distinct()
    )

    # The dataset is not modified by this loop, so count the partitions once
    # instead of re-running the query on every iteration.
    total_partitions = partition_table.count()

    # Number of partitions handed to `process_partitions` per iteration.
    batch_size = 3

    while True:
        missing_partitions = find_missing_partitions(partition_table, status_table)
        print(f"{len(missing_partitions)} of {total_partitions} partitions have not processed.")

        if not missing_partitions:
            break

        process_partitions(client, dataset, missing_partitions[:batch_size])

    display(results_table)

    display(status_table)