diff --git a/docs/examples/_server.py b/docs/examples/_server.py index 12b370445..83ea23b62 100644 --- a/docs/examples/_server.py +++ b/docs/examples/_server.py @@ -1,11 +1,13 @@ import asyncio +from datetime import datetime from grpclib.utils import graceful_exit from grpclib.server import Server, Stream from google.protobuf.struct_pb2 import Struct, Value from google.protobuf.timestamp_pb2 import Timestamp -from viam.utils import dict_to_struct, value_to_primitive +from viam.app.data_client import DataClient +from viam.utils import datetime_to_timestamp, dict_to_struct, value_to_primitive from viam.proto.app.data import ( AddBoundingBoxToImageByIDResponse, AddBoundingBoxToImageByIDRequest, @@ -19,6 +21,7 @@ BinaryDataByIDsResponse, BoundingBoxLabelsByFilterRequest, BoundingBoxLabelsByFilterResponse, + CaptureMetadata, DataServiceBase, DeleteBinaryDataByFilterRequest, DeleteBinaryDataByFilterResponse, @@ -169,7 +172,26 @@ class MockData(DataServiceBase): def __init__(self): self.tabular_data_requested = False - self.tabular_response = [{"PowerPct": 0, "IsPowered": False}, {"PowerPct": 0, "IsPowered": False}, {"Position": 0}] + self.tabular_response = [ + DataClient.TabularData( + {"PowerPct": 0, "IsPowered": False}, + CaptureMetadata(method_name="IsPowered"), + datetime(2022, 1, 1, 1, 1, 1), + datetime(2022, 12, 31, 23, 59, 59), + ), + DataClient.TabularData( + {"PowerPct": 0, "IsPowered": False}, + CaptureMetadata(location_id="loc-id"), + datetime(2023, 1, 2), + datetime(2023, 3, 4) + ), + DataClient.TabularData( + {"Position": 0}, + CaptureMetadata(), + datetime(2023, 5, 6), + datetime(2023, 7, 8), + ), + ] async def TabularDataByFilter(self, stream: Stream[TabularDataByFilterRequest, TabularDataByFilterResponse]) -> None: if self.tabular_data_requested: @@ -177,11 +199,21 @@ async def TabularDataByFilter(self, stream: Stream[TabularDataByFilterRequest, T return self.tabular_data_requested = True _ = await stream.recv_message() - n = len(self.tabular_response) - tabular_structs = [Struct()] * n - for i in range(n): - tabular_structs[i].update(self.tabular_response[i]) - await stream.send_message(TabularDataByFilterResponse(data=[TabularData(data=struct) for struct in tabular_structs])) + tabular_structs = [] + tabular_metadata = [data.metadata for data in self.tabular_response] + for idx, tabular_data in enumerate(self.tabular_response): + tabular_structs.append( + TabularData( + data=dict_to_struct(tabular_data.data), + metadata_index=idx, + time_requested=datetime_to_timestamp(tabular_data.time_requested), + time_received=datetime_to_timestamp(tabular_data.time_received) + ) + ) + await stream.send_message(TabularDataByFilterResponse( + data=tabular_structs, metadata=tabular_metadata, + ) + ) async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, BinaryDataByFilterResponse]) -> None: pass diff --git a/src/viam/app/data_client.py b/src/viam/app/data_client.py index 24b72d2f4..8d796e08f 100644 --- a/src/viam/app/data_client.py +++ b/src/viam/app/data_client.py @@ -14,9 +14,11 @@ BinaryDataByIDsRequest, BinaryDataByIDsResponse, BinaryID, + BinaryMetadata, BoundingBoxLabelsByFilterRequest, BoundingBoxLabelsByFilterResponse, CaptureInterval, + CaptureMetadata, DataRequest, DataServiceStub, DeleteBinaryDataByFilterRequest, @@ -60,6 +62,54 @@ class DataClient: `ViamClient`. """ + class TabularData: + """Class representing a piece of tabular data and associated metadata. + + Args: + data (Mapping[str, Any]): the requested data. + metadata (viam.proto.app.data.CaptureMetadata): the metadata from the request. + time_requested (datetime): the time the data request was sent. + time_received (datetime): the time the requested data was received. + """ + + def __init__(self, data: Mapping[str, Any], metadata: CaptureMetadata, time_requested: datetime, time_received: datetime) -> None: + self.data = data + self.metadata = metadata + self.time_requested = time_requested + self.time_received = time_received + + data: Mapping[str, Any] + metadata: CaptureMetadata + time_requested: datetime + time_received: datetime + + def __str__(self) -> str: + return f"{self.data}\n{self.metadata}Time requested: {self.time_requested}\nTime received: {self.time_received}\n" + + def __eq__(self, other: "DataClient.TabularData") -> bool: + return str(self) == str(other) + + class BinaryData: + """Class representing a piece of binary data and associated metadata. + + Args: + data (bytes): the requested data. + metadata (viam.proto.app.data.BinaryMetadata): the metadata from the request. + """ + + def __init__(self, data: bytes, metadata: BinaryMetadata) -> None: + self.data = data + self.metadata = metadata + + data: bytes + metadata: BinaryMetadata + + def __str__(self) -> str: + return f"{self.data}\n{self.metadata}" + + def __eq__(self, other: "DataClient.BinaryData") -> bool: + return str(self) == str(other) + def __init__(self, channel: Channel, metadata: Mapping[str, str]): """Create a `DataClient` that maintains a connection to app. @@ -79,7 +129,7 @@ async def tabular_data_by_filter( self, filter: Optional[Filter] = None, dest: Optional[str] = None, - ) -> List[Mapping[str, Any]]: + ) -> List[TabularData]: """Filter and download tabular data. Args: @@ -102,13 +152,18 @@ async def tabular_data_by_filter( response: TabularDataByFilterResponse = await self._data_client.TabularDataByFilter(request, metadata=self._metadata) if not response.data or len(response.data) == 0: break - data += [struct_to_dict(struct.data) for struct in response.data] + data += [DataClient.TabularData( + struct_to_dict(struct.data), + response.metadata[struct.metadata_index], + struct.time_requested.ToDatetime(), + struct.time_received.ToDatetime(), + ) for struct in response.data] last = response.last if dest: try: file = open(dest, "w") - file.write(f"{data}") + file.write(f"{[str(d) for d in data]}") except Exception as e: LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e) return data @@ -117,7 +172,7 @@ async def binary_data_by_filter( self, filter: Optional[Filter] = None, dest: Optional[str] = None, - ) -> List[bytes]: + ) -> List[BinaryData]: """Filter and download binary data. Args: @@ -139,13 +194,13 @@ async def binary_data_by_filter( response: BinaryDataByFilterResponse = await self._data_client.BinaryDataByFilter(request, metadata=self._metadata) if not response.data or len(response.data) == 0: break - data += [data.binary for data in response.data] + data += [DataClient.BinaryData(data.binary, data.metadata) for data in response.data] last = response.last if dest: try: file = open(dest, "w") - file.write(f"{data}") + file.write(f"{[str(d) for d in data]}") except Exception as e: LOGGER.error(f"Failed to write binary data to file {dest}", exc_info=e) @@ -155,7 +210,7 @@ async def binary_data_by_ids( self, binary_ids: List[BinaryID], dest: Optional[str] = None, - ) -> List[bytes]: + ) -> List[BinaryData]: """Filter and download binary data. Args: @@ -176,7 +231,7 @@ async def binary_data_by_ids( file.write(f"{response.data}") except Exception as e: LOGGER.error(f"Failed to write binary data to file {dest}", exc_info=e) - return [binary_data.binary for binary_data in response.data] + return [DataClient.BinaryData(data.binary, data.metadata) for data in response.data] async def delete_tabular_data_by_filter(self, filter: Optional[Filter]) -> int: """Filter and delete tabular data. diff --git a/tests/mocks/services.py b/tests/mocks/services.py index 9eaf30449..8f273b4b5 100644 --- a/tests/mocks/services.py +++ b/tests/mocks/services.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Mapping, Optional, Union -from google.protobuf.struct_pb2 import Struct from grpclib.server import Stream from PIL import Image @@ -191,11 +190,12 @@ SensorsServiceBase, ) from viam.proto.service.vision import Classification, Detection +from viam.app.data_client import DataClient from viam.services.mlmodel import File, LabelType, Metadata, MLModel, TensorInfo from viam.services.navigation import Navigation from viam.services.slam import SLAM from viam.services.vision import Vision -from viam.utils import ValueTypes, struct_to_dict +from viam.utils import ValueTypes, datetime_to_timestamp, dict_to_struct, struct_to_dict class MockVision(Vision): @@ -483,8 +483,8 @@ async def do_command(self, command: Mapping[str, ValueTypes], *, timeout: Option class MockData(DataServiceBase): def __init__( self, - tabular_response: List[Mapping[str, Any]], - binary_response: List[bytes], + tabular_response: List[DataClient.TabularData], + binary_response: List[DataClient.BinaryData], delete_remove_response: int, tags_response: List[str], bbox_labels_response: List[str], @@ -504,11 +504,20 @@ async def TabularDataByFilter(self, stream: Stream[TabularDataByFilterRequest, T await stream.send_message(TabularDataByFilterResponse(data=None)) return self.filter = request.data_request.filter - n = len(self.tabular_response) - tabular_response_structs = [Struct()] * n - for i in range(n): - tabular_response_structs[i].update(self.tabular_response[i]) - await stream.send_message(TabularDataByFilterResponse(data=[TabularData(data=struct) for struct in tabular_response_structs])) + tabular_response_structs = [] + tabular_metadata = [data.metadata for data in self.tabular_response] + for idx, tabular_data in enumerate(self.tabular_response): + tabular_response_structs.append( + TabularData( + data=dict_to_struct(tabular_data.data), + metadata_index=idx, + time_requested=datetime_to_timestamp(tabular_data.time_requested), + time_received=datetime_to_timestamp(tabular_data.time_received) + ) + ) + await stream.send_message(TabularDataByFilterResponse( + data=tabular_response_structs, metadata=tabular_metadata) + ) self.was_tabular_data_requested = True async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, BinaryDataByFilterResponse]) -> None: @@ -518,14 +527,18 @@ async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, Bin await stream.send_message(BinaryDataByFilterResponse()) return self.filter = request.data_request.filter - await stream.send_message(BinaryDataByFilterResponse(data=[BinaryData(binary=binary_data) for binary_data in self.binary_response])) + await stream.send_message(BinaryDataByFilterResponse( + data=[BinaryData(binary=data.data, metadata=data.metadata) for data in self.binary_response]) + ) self.was_binary_data_requested = True async def BinaryDataByIDs(self, stream: Stream[BinaryDataByIDsRequest, BinaryDataByIDsResponse]) -> None: request = await stream.recv_message() assert request is not None self.binary_ids = request.binary_ids - await stream.send_message(BinaryDataByIDsResponse(data=[BinaryData(binary=binary_data) for binary_data in self.binary_response])) + await stream.send_message(BinaryDataByIDsResponse( + data=[BinaryData(binary=data.data, metadata=data.metadata) for data in self.binary_response]) + ) async def DeleteTabularDataByFilter(self, stream: Stream[DeleteTabularDataByFilterRequest, DeleteTabularDataByFilterResponse]) -> None: request = await stream.recv_message() diff --git a/tests/test_data_client.py b/tests/test_data_client.py index 47198e391..d1e1c991c 100644 --- a/tests/test_data_client.py +++ b/tests/test_data_client.py @@ -1,5 +1,4 @@ import pytest -from datetime import datetime from typing import List from grpclib.testing import ChannelFor @@ -7,10 +6,12 @@ from viam.app.data_client import DataClient from viam.proto.app.data import ( - Filter, + Annotations, BinaryID, - CaptureInterval, - TagsFilter + BinaryMetadata, + BoundingBox, + CaptureMetadata, + Filter, ) from .mocks.services import MockData @@ -26,16 +27,20 @@ LOCATION_IDS = [LOCATION_ID] ORG_ID = "organization_id" ORG_IDS = [ORG_ID] -MIME_TYPES = ["mime_type"] -START_DATETIME = datetime(2001, 1, 1, 1, 1, 1) -END_DATETIME = datetime(2001, 1, 1, 1, 1, 1) +MIME_TYPE = "mime_type" +MIME_TYPES = [MIME_TYPE] +URI = "some.robot.uri" SECONDS_START = 978310861 NANOS_START = 0 SECONDS_END = 978310861 NANOS_END = 0 +START_TS = Timestamp(seconds=SECONDS_START, nanos=NANOS_START) +END_TS = Timestamp(seconds=SECONDS_END, nanos=NANOS_END) +START_DATETIME = START_TS.ToDatetime() +END_DATETIME = END_TS.ToDatetime() TAGS = ["tag"] BBOX_LABELS = ["bbox_label"] -FILTER = Filter( +FILTER = DataClient.create_filter( component_name=COMPONENT_NAME, component_type=COMPONENT_TYPE, method=METHOD, @@ -46,21 +51,12 @@ location_ids=LOCATION_IDS, organization_ids=ORG_IDS, mime_type=MIME_TYPES, - interval=CaptureInterval( - start=Timestamp( - seconds=SECONDS_START, - nanos=NANOS_START, - ), - end=Timestamp( - seconds=SECONDS_END, - nanos=NANOS_END - ) - ), - tags_filter=TagsFilter( - tags=TAGS - ), - bbox_labels=BBOX_LABELS, + start_time=START_DATETIME, + end_time=END_DATETIME, + tags=TAGS, + bbox_labels=BBOX_LABELS ) + FILE_ID = "file_id" BINARY_IDS = [BinaryID( file_id=FILE_ID, @@ -68,25 +64,49 @@ location_id=LOCATION_ID )] BINARY_DATA = b'binary_data' -TIMESTAMPS = [( - Timestamp( - seconds=SECONDS_START, - nanos=NANOS_START - ), - Timestamp( - seconds=SECONDS_END, - nanos=NANOS_END - ) -)] -TABULAR_DATA = [{"key": "value"}] FILE_NAME = "file_name" FILE_EXT = "file_extension" +BBOX_LABEL = "bbox_label" +BBOX_LABELS_RESPONSE = [BBOX_LABEL] +BBOX = BoundingBox( + id="id", + label=BBOX_LABEL, + x_min_normalized=0, + y_min_normalized=1, + x_max_normalized=2, + y_max_normalized=3, +) +BBOXES = [BBOX] +TABULAR_DATA = {"key": "value"} +TABULAR_METADATA = CaptureMetadata( + organization_id=ORG_ID, + location_id=LOCATION_ID, + robot_name=ROBOT_NAME, + robot_id=ROBOT_ID, + part_name=PART_NAME, + part_id=PART_ID, + component_type=COMPONENT_TYPE, + component_name=COMPONENT_NAME, + method_name=METHOD, + method_parameters={}, + tags=TAGS, + mime_type=MIME_TYPE, +) +BINARY_METADATA = BinaryMetadata( + id="id", + capture_metadata=TABULAR_METADATA, + time_requested=START_TS, + time_received=END_TS, + file_name=FILE_NAME, + file_ext=FILE_EXT, + uri=URI, + annotations=Annotations(bboxes=BBOXES), +) -TABULAR_RESPONSE = TABULAR_DATA -BINARY_RESPONSE = [BINARY_DATA] +TABULAR_RESPONSE = [DataClient.TabularData(TABULAR_DATA, TABULAR_METADATA, START_DATETIME, END_DATETIME)] +BINARY_RESPONSE = [DataClient.BinaryData(BINARY_DATA, BINARY_METADATA)] DELETE_REMOVE_RESPONSE = 1 TAGS_RESPONSE = ["tag"] -BBOX_LABELS_RESPONSE = ["bbox_label"] AUTH_TOKEN = "auth_token" DATA_SERVICE_METADATA = {"authorization": f"Bearer {AUTH_TOKEN}"} @@ -108,22 +128,7 @@ class TestClient: async def test_tabular_data_by_filter(self, service: MockData): async with ChannelFor([service]) as channel: client = DataClient(channel, DATA_SERVICE_METADATA) - tabular_data = await client.tabular_data_by_filter(filter=client.create_filter( - component_name=COMPONENT_NAME, - component_type=COMPONENT_TYPE, - method=METHOD, - robot_name=ROBOT_NAME, - robot_id=ROBOT_ID, - part_name=PART_NAME, - part_id=PART_ID, - location_ids=LOCATION_IDS, - organization_ids=ORG_IDS, - mime_type=MIME_TYPES, - start_time=START_DATETIME, - end_time=START_DATETIME, - tags=TAGS, - bbox_labels=BBOX_LABELS - )) + tabular_data = await client.tabular_data_by_filter(filter=FILTER) assert tabular_data == TABULAR_RESPONSE self.assert_filter(filter=service.filter)