Skip to content

Commit

Permalink
Merge pull request #25 from vespa-engine/tgm/add-progress-info-collec…
Browse files Browse the repository at this point in the history
…t-data

track data collection progress
  • Loading branch information
Thiago G. Martins committed Oct 19, 2020
2 parents 24cb601 + f401971 commit 1f8f251
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions vespa/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

from typing import Optional, Dict, Tuple, List
import sys
from typing import Optional, Dict, Tuple, List, IO
from pandas import DataFrame
from requests import post
from requests.models import Response
Expand All @@ -16,6 +17,7 @@ def __init__(
port: Optional[int] = None,
deployment_message: Optional[List[str]] = None,
cert: Optional[str] = None,
output_file: IO = sys.stdout,
) -> None:
"""
Establish a connection with a Vespa application.
Expand All @@ -24,12 +26,14 @@ def __init__(
:param port: Vespa instance port.
:param deployment_message: Message returned by Vespa engine after deployment. Used internally by deploy methods.
:param cert: Path to certificate and key file.
:param output_file: Output file to write output messages.
>>> Vespa(url = "https://cord19.vespa.ai")
>>> Vespa(url = "http://localhost", port = 8080)
>>> Vespa(url = "https://api.vespa-external.aws.oath.cloud", port = 4443, cert = "/path/to/cert-and-key.pem")
"""
self.output_file = output_file
self.url = url
self.port = port
self.deployment_message = deployment_message
Expand Down Expand Up @@ -202,6 +206,7 @@ def collect_training_data(
number_additional_docs: int,
relevant_score: int = 1,
default_score: int = 0,
show_progress: Optional[int] = None,
**kwargs
) -> DataFrame:
"""
Expand All @@ -213,14 +218,32 @@ def collect_training_data(
:param number_additional_docs: Number of additional documents to retrieve for each relevant document.
:param relevant_score: Score to assign to relevant documents. Default to 1.
:param default_score: Score to assign to the additional documents that are not relevant. Default to 0.
:param show_progress: Prints the the current point being collected every `show_progress` step. Default to None,
in which case progress is not printed.
:param kwargs: Extra keyword arguments to be included in the Vespa Query.
:return: DataFrame containing document id (document_id), query id (query_id), scores (relevant)
and vespa rank features returned by the Query model RankProfile used.
"""

training_data = []
for query_data in labelled_data:
for doc_data in query_data["relevant_docs"]:
number_queries = len(labelled_data)
idx_total = 0
for query_idx, query_data in enumerate(labelled_data):
number_relevant_docs = len(query_data["relevant_docs"])
for doc_idx, doc_data in enumerate(query_data["relevant_docs"]):
idx_total += 1
if (show_progress is not None) and (idx_total % show_progress == 0):
print(
"Query {}/{}, Doc {}/{}. Query id: {}. Doc id: {}".format(
query_idx,
number_queries,
doc_idx,
number_relevant_docs,
query_data["query_id"],
doc_data["id"],
),
file=self.output_file,
)
training_data_point = self.collect_training_data_point(
query=query_data["query"],
query_id=query_data["query_id"],
Expand Down

0 comments on commit 1f8f251

Please sign in to comment.