
feat: allow rb.load fetch records in batches passing the from_id argument

Thanks to @maxserras

(cherry picked from commit 36dbb4f)

fix: normalize id sort config (#1660)
(cherry picked from commit ca8ea92)
maxserras authored and frascuchon committed Aug 22, 2022
1 parent 854a972 commit 3e6344a
Showing 16 changed files with 164 additions and 30 deletions.
39 changes: 33 additions & 6 deletions src/rubrix/client/api.py
@@ -408,24 +408,50 @@ def load(
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
as_pandas=None,
) -> Dataset:
"""Loads a Rubrix dataset.
Parameters:
-----------
name:
The dataset name.
query:
An ElasticSearch query with the
`query string syntax <https://rubrix.readthedocs.io/en/stable/guides/queries.html>`_
ids:
If provided, load dataset records with given ids.
limit:
The number of records to retrieve.
id_from:
If provided, records are gathered starting from that record ID. Since the records returned by the
load method are sorted by ID, ``id_from`` can be used to load the dataset in batches.
as_pandas:
DEPRECATED! To get a pandas DataFrame do ``rb.load('my_dataset').to_pandas()``.
Returns:
--------
A Rubrix dataset.
Examples:
**Basic Loading**: load the samples sorted by their ID
>>> import rubrix as rb
>>> dataset = rb.load(name="example-dataset")
**Iterate over a large dataset:**
When dealing with a large dataset you might want to load it in batches to optimize memory consumption
and avoid network timeouts. To that end, a simple batch-iteration over the whole dataset can be done
using the `id_from` parameter. This parameter acts as a delimiter, retrieving the N records after
the given record ID, where N is determined by the `limit` parameter. **NOTE**: if no `limit` is
given, the whole dataset after that ID is retrieved.
>>> import rubrix as rb
>>> dataset_batch_1 = rb.load(name="example-dataset", limit=1000)
>>> dataset_batch_2 = rb.load(name="example-dataset", limit=1000, id_from=dataset_batch_1[-1].id)
"""
if as_pandas is False:
warnings.warn(
@@ -473,6 +499,7 @@ def load(
name=name,
request=request_class(ids=ids, query_text=query),
limit=limit,
id_from=id_from,
)

records = [sdk_record.to_client() for sdk_record in response.parsed]
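The batch-iteration pattern described in the new ``load`` docstring can be wrapped in a small loop. A minimal sketch, assuming a dataset named ``example-dataset`` and relying only on what the docstring shows (records sorted by ID and exposing an ``.id`` attribute):

```python
import rubrix as rb

BATCH_SIZE = 1000
last_id = None
batches = []

while True:
    # On the first pass id_from is None, so loading starts at the beginning
    # of the (ID-sorted) dataset; afterwards it resumes after the last ID seen.
    batch = rb.load(name="example-dataset", limit=BATCH_SIZE, id_from=last_id)
    if len(batch) == 0:
        break
    batches.append(batch)
    last_id = batch[-1].id
```

Each iteration issues one request capped at ``limit`` records, which keeps memory usage bounded for large datasets.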
11 changes: 10 additions & 1 deletion src/rubrix/client/sdk/commons/api.py
@@ -26,7 +26,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Type, TypeVar, Union, Optional, Dict

import httpx

@@ -49,6 +49,15 @@
}


def build_param_dict(id_from: Optional[str], limit: Optional[int]) -> Optional[Dict[str, Union[str, int]]]:
params = {}
if id_from:
params['id_from'] = id_from
if limit:
params['limit'] = limit
return params


def bulk(
client: AuthenticatedClient,
name: str,
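The new ``build_param_dict`` helper simply drops unset values so that empty query parameters are never sent. A quick illustration of its behaviour (the record ID below is a made-up value):

```python
from rubrix.client.sdk.commons.api import build_param_dict

print(build_param_dict(id_from=None, limit=None))   # {}
print(build_param_dict(id_from=None, limit=1000))   # {'limit': 1000}
print(build_param_dict(id_from="record-0999", limit=1000))
# {'id_from': 'record-0999', 'limit': 1000}
```

The callers below turn an empty dict back into ``None`` (``params if params else None``) so that no query string is appended when neither argument is given.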
7 changes: 6 additions & 1 deletion src/rubrix/client/sdk/text2text/api.py
@@ -21,19 +21,24 @@
from rubrix.client.sdk.commons.models import ErrorMessage, HTTPValidationError, Response
from rubrix.client.sdk.text2text.models import Text2TextQuery, Text2TextRecord

from rubrix.client.sdk.commons.api import build_param_dict


def data(
client: AuthenticatedClient,
name: str,
request: Optional[Text2TextQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:

path = f"/api/datasets/{name}/Text2Text/data"
params = build_param_dict(id_from, limit)

with client.stream(
method="POST",
path=path,
params={"limit": limit} if limit else None,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(response=response, data_type=Text2TextRecord)
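Under the hood each of these SDK ``data`` functions issues a streaming POST against the dataset's data endpoint, with ``limit`` and ``id_from`` as query parameters and the query model as the JSON body. A rough equivalent with plain ``httpx``, assuming a local Rubrix server on the default port and newline-delimited JSON in the response; the API key value and auth header are illustrative assumptions:

```python
import json

import httpx

params = {"limit": 1000, "id_from": "record-0999"}  # what build_param_dict would produce
with httpx.stream(
    "POST",
    "http://localhost:6900/api/datasets/example-dataset/Text2Text/data",
    params=params,
    json={},  # an empty Text2TextQuery
    headers={"Authorization": "Bearer rubrix.apikey"},  # assumed auth scheme
    timeout=30.0,
) as response:
    for line in response.iter_lines():
        if not line:
            continue
        record = json.loads(line)
        print(record["id"])
```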
10 changes: 7 additions & 3 deletions src/rubrix/client/sdk/text_classification/api.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Dict

import httpx

@@ -27,20 +27,24 @@
TextClassificationRecord,
)

from rubrix.client.sdk.commons.api import build_param_dict


def data(
client: AuthenticatedClient,
name: str,
request: Optional[TextClassificationQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[TextClassificationRecord], HTTPValidationError, ErrorMessage]]:

path = f"/api/datasets/{name}/TextClassification/data"

params = build_param_dict(id_from, limit)

with client.stream(
method="POST",
path=path,
params={"limit": limit} if limit else None,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(
7 changes: 6 additions & 1 deletion src/rubrix/client/sdk/token_classification/api.py
@@ -25,21 +25,26 @@
TokenClassificationRecord,
)

from rubrix.client.sdk.commons.api import build_param_dict


def data(
client: AuthenticatedClient,
name: str,
request: Optional[TokenClassificationQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[
Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
]:

path = f"/api/datasets/{name}/TokenClassification/data"
params = build_param_dict(id_from, limit)

with client.stream(
path=path,
method="POST",
params={"limit": limit} if limit else None,
params=params if params else None,
json=request.dict() if request else {},
) as response:
return build_data_response(
5 changes: 4 additions & 1 deletion src/rubrix/server/apis/v0/handlers/text2text.py
@@ -245,6 +245,7 @@ async def stream_data(
service: Text2TextService = Depends(text2text_service),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
id_from: Optional[str] = None
) -> StreamingResponse:
"""
Creates a data stream over dataset records
@@ -265,6 +266,8 @@
The datasets service
current_user:
Request user
id_from:
If provided, read the samples after this record ID
"""
query = query or Text2TextQuery()
@@ -275,7 +278,7 @@
workspace=common_params.workspace,
as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
)
data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit)

return scan_data_response(
data_stream=data_stream,
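On the server side, adding ``id_from: Optional[str] = None`` to the handler signature is all FastAPI needs to expose it as an optional query parameter, alongside the existing ``limit``. A stripped-down sketch of that mechanism (not the actual handler, which also wires in the services, auth and the streaming response):

```python
from typing import Optional

from fastapi import FastAPI, Query

app = FastAPI()


@app.post("/api/datasets/{name}/Text2Text/data")
async def stream_data(
    name: str,
    limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
    id_from: Optional[str] = None,  # plain default -> optional ?id_from=... query parameter
):
    # The real handler builds a record stream from the service layer;
    # here we just echo the pagination arguments that were received.
    return {"name": name, "limit": limit, "id_from": id_from}
```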
7 changes: 6 additions & 1 deletion src/rubrix/server/apis/v0/handlers/text_classification.py
@@ -180,6 +180,7 @@ def search_records(
current_user:
The current request user
Returns
-------
The search results data
@@ -262,6 +263,7 @@ async def stream_data(
name: str,
query: Optional[TextClassificationQuery] = None,
common_params: CommonTaskQueryParams = Depends(),
id_from: Optional[str] = None,
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
service: TextClassificationService = Depends(
TextClassificationService.get_instance
@@ -289,6 +291,9 @@
current_user:
Request user
id_from:
Search after the given record ID
"""
query = query or TextClassificationQuery()
dataset = datasets.find_by_name(
@@ -299,7 +304,7 @@
as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
)

data_stream = service.read_dataset(dataset, query=query, id_from=id_from, limit=limit)
return scan_data_response(
data_stream=data_stream,
limit=limit,
5 changes: 4 additions & 1 deletion src/rubrix/server/apis/v0/handlers/token_classification.py
@@ -262,6 +262,7 @@ async def stream_data(
service: TokenClassificationService = Depends(token_classification_service),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
id_from: Optional[str] = None,
) -> StreamingResponse:
"""
Creates a data stream over dataset records
@@ -282,6 +283,8 @@
The datasets service
current_user:
Request user
id_from:
If provided, read the samples after this record ID
"""
query = query or TokenClassificationQuery()
@@ -292,7 +295,7 @@
workspace=common_params.workspace,
as_dataset_class=TaskFactory.get_task_dataset(TASK_TYPE),
)
data_stream = service.read_dataset(dataset=dataset, query=query, id_from=id_from, limit=limit)

return scan_data_response(
data_stream=data_stream,
5 changes: 5 additions & 0 deletions src/rubrix/server/daos/datasets.py
@@ -33,6 +33,7 @@
BaseDatasetDB = TypeVar("BaseDatasetDB", bound=DatasetDB)

NO_WORKSPACE = ""
MAX_NUMBER_OF_LISTED_DATASETS = 2500


class SettingsDB(BaseModel):
@@ -120,6 +121,9 @@ def list_datasets(

docs = self._es.list_documents(
index=DATASETS_INDEX_NAME,
fetch_once=True,
# TODO(@frascuchon): include id as part of the document as keyword to enable sorting by id
size=MAX_NUMBER_OF_LISTED_DATASETS,
query={
"query": query_helpers.filters.boolean_filter(
should_filters=filters, minimum_should_match=len(filters)
@@ -239,6 +243,7 @@ def find_by_name(
results = self._es.list_documents(
index=DATASETS_INDEX_NAME,
query={"query": {"term": {"name.keyword": name}}},
fetch_once=True,
)
results = list(results)
if len(results) == 0:
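The ``fetch_once=True`` and ``size`` arguments cap dataset listing at a single search request instead of scrolling through every match. A rough sketch of that distinction with the plain ``elasticsearch`` client (index name, connection and the exact ``list_documents`` semantics are assumptions here):

```python
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan

es = Elasticsearch("http://localhost:9200")
index = "rubrix-datasets"  # placeholder for DATASETS_INDEX_NAME
body = {"query": {"term": {"name.keyword": "example-dataset"}}}

# "fetch once": a single search call, capped at `size` hits
first_page = es.search(index=index, body=body, size=2500)["hits"]["hits"]

# full scroll: the scan helper pages through every matching document
all_docs = list(scan(es, index=index, query=body))
```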
1 change: 1 addition & 0 deletions src/rubrix/server/daos/models/records.py
@@ -31,6 +31,7 @@ class RecordSearch(BaseModel):
The elasticsearch sort order
aggregations:
The elasticsearch search aggregations
"""

query: Optional[Dict[str, Any]] = None
29 changes: 26 additions & 3 deletions src/rubrix/server/daos/records.py
@@ -268,7 +268,9 @@ def __normalize_sort_config__(
def scan_dataset(
self,
dataset: BaseDatasetDB,
limit: int = 1000,
search: Optional[RecordSearch] = None,
id_from: Optional[str] = None,
) -> Iterable[Dict[str, Any]]:
"""
Iterates over a dataset's records
@@ -279,19 +281,40 @@
The dataset
search:
The search parameters. Optional
limit:
Batch size to extract; only used when an `id_from` is provided
id_from:
The record ID from which to start iterating
Returns
-------
An iterable over found dataset records
"""
index = dataset_records_index(dataset.id)
search = search or RecordSearch()

sort_cfg = self.__normalize_sort_config__(
index=index, sort=[{"id": {"order": "asc"}}]
)
es_query = {
"query": search.query or {"match_all": {}},
"highlight": self.__configure_query_highlight__(task=dataset.task),
"sort": sort_cfg, # Sort the search so the consistency is maintained in every search
}

if id_from:
# The scroll API does not support search_after, so this case is handled as a regular search
es_query["search_after"] = [id_from]
results = self._es.search(index=index, query=es_query, size=limit)
hits = results["hits"]
docs = hits["hits"]

else:
docs = self._es.list_documents(
index,
query=es_query,
sort_cfg=sort_cfg,
)
for doc in docs:
yield self.__esdoc2record__(doc)

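The ``id_from`` branch above relies on Elasticsearch ``search_after`` pagination: results are sorted by ``id`` ascending and each request resumes after the last ID already seen. A minimal sketch of that query shape against a raw client (index name, record ID and client setup are placeholders; in the DAO the sort field is first normalized by ``__normalize_sort_config__``):

```python
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
index = "rubrix-example-dataset-records"  # placeholder for dataset_records_index(dataset.id)

body = {
    "query": {"match_all": {}},
    "sort": [{"id": {"order": "asc"}}],  # keep ordering stable across requests
    "search_after": ["record-0999"],     # last record ID from the previous batch
}
page = es.search(index=index, body=body, size=1000)
for hit in page["hits"]["hits"]:
    print(hit["_source"]["id"])
```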
