diff --git a/src/viam/app/data_client.py b/src/viam/app/data_client.py index 71840ac68..e35dd1fe7 100644 --- a/src/viam/app/data_client.py +++ b/src/viam/app/data_client.py @@ -174,33 +174,48 @@ async def tabular_data_by_filter( return data async def binary_data_by_filter( - self, - filter: Optional[Filter] = None, - dest: Optional[str] = None, + self, filter: Optional[Filter] = None, dest: Optional[str] = None, include_file_data: bool = True, num_files: Optional[int] = None ) -> List[BinaryData]: """Filter and download binary data. Args: - filter (viam.proto.app.data.Filter): Optional `Filter` specifying binary data to retrieve. No `Filter` implies all binary data. - dest (str): Optional filepath for writing retrieved data. + filter (Optional[viam.proto.app.data.Filter]): Optional `Filter` specifying binary data to retrieve. No `Filter` implies all + binary data. + dest (Optional[str]): Optional filepath for writing retrieved data. + include_file_data (bool): Boolean specifying whether to actually include the binary file data with each retrieved file. Defaults + to true (i.e., both the files' data and metadata are returned). + num_files (Optional[str]): Number of binary data to return. Passing 0 returns all binary data matching the filter no matter. + Defaults to 100 if no binary data is requested, otherwise 10. All binary data or the first `num_files` will be returned, + whichever comes first. + + Raises: + ValueError: If `num_files` is less than 0. Returns: List[bytes]: The binary data. """ + num_files = num_files if num_files else 10 if include_file_data else 100 + if num_files < 0: + raise ValueError("num_files must be at least 0.") filter = filter if filter else Filter() + limit = 1 if include_file_data else 100 last = "" data = [] - # `DataRequest`s are limited to 100 pieces of data, so we loop through calls until + # `DataRequest`s are limited in pieces of data, so we loop through calls until # we are certain we've received everything. while True: - data_request = DataRequest(filter=filter, limit=100, last=last) - request = BinaryDataByFilterRequest(data_request=data_request, count_only=False) - response: BinaryDataByFilterResponse = await self._data_client.BinaryDataByFilter(request, metadata=self._metadata) - if not response.data or len(response.data) == 0: + new_data, last = await self._binary_data_by_filter(filter=filter, limit=limit, include_binary=include_file_data, last=last) + if not new_data or len(new_data) == 0: break - data += [DataClient.BinaryData(data.binary, data.metadata) for data in response.data] - last = response.last + elif num_files != 0 and len(new_data) > num_files: + data += new_data[0:num_files] + break + else: + data += new_data + num_files -= len(new_data) + if num_files == 0: + break if dest: try: @@ -211,6 +226,12 @@ async def binary_data_by_filter( return data + async def _binary_data_by_filter(self, filter: Filter, limit: int, include_binary: bool, last: str) -> Tuple[List[BinaryData], str]: + data_request = DataRequest(filter=filter, limit=limit, last=last) + request = BinaryDataByFilterRequest(data_request=data_request, count_only=False, include_binary=include_binary) + response: BinaryDataByFilterResponse = await self._data_client.BinaryDataByFilter(request, metadata=self._metadata) + return [DataClient.BinaryData(data.binary, data.metadata) for data in response.data], response.last + async def binary_data_by_ids( self, binary_ids: List[BinaryID], diff --git a/tests/mocks/services.py b/tests/mocks/services.py index 95dc2b12c..4dab00b8e 100644 --- a/tests/mocks/services.py +++ b/tests/mocks/services.py @@ -528,6 +528,7 @@ async def BinaryDataByFilter(self, stream: Stream[BinaryDataByFilterRequest, Bin await stream.send_message(BinaryDataByFilterResponse()) return self.filter = request.data_request.filter + self.include_binary = request.include_binary await stream.send_message( BinaryDataByFilterResponse(data=[BinaryData(binary=data.data, metadata=data.metadata) for data in self.binary_response]) ) diff --git a/tests/test_data_client.py b/tests/test_data_client.py index 8575d6b43..55f817cb1 100644 --- a/tests/test_data_client.py +++ b/tests/test_data_client.py @@ -16,6 +16,7 @@ from .mocks.services import MockData +INCLUDE_BINARY = True COMPONENT_NAME = "component_name" COMPONENT_TYPE = "component_type" METHOD = "method" @@ -136,7 +137,8 @@ async def test_tabular_data_by_filter(self, service: MockData): async def test_binary_data_by_filter(self, service: MockData): async with ChannelFor([service]) as channel: client = DataClient(channel, DATA_SERVICE_METADATA) - binary_data = await client.binary_data_by_filter(filter=FILTER) + binary_data = await client.binary_data_by_filter(filter=FILTER, include_file_data=INCLUDE_BINARY) + assert service.include_binary == INCLUDE_BINARY assert binary_data == BINARY_RESPONSE self.assert_filter(filter=service.filter)