Tag custom ranking metrics as output columns (#1500)
## Description

This PR:

- On write(), searches for known custom ranking metric columns and tags them
as output columns (see the sketch after this list)
- Since tag_output_column requires the columns to already exist, adds a new
_set_column_schema method that can be called before the actual columns are
uploaded
- Includes an incidental fix to the integration tests for performance columns
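
As a rough sketch of the matching described in the first bullet (not the PR code itself; the helper and example column names below are purely illustrative), the tagging boils down to a prefix match against the known ranking-metric names:

```python
# Minimal sketch, assuming ranking-metric columns are named with known prefixes
# (e.g. "precision_k_3"); this helper and these example names are illustrative only.
KNOWN_CUSTOM_OUTPUT_METRICS = {
    "precision_k_": ("fractional", "continuous"),
    "recall_k_": ("fractional", "continuous"),
    "top_rank": ("integral", "continuous"),
}


def columns_to_tag(column_names):
    """Yield (column, data_type, discreteness) for columns matching a known metric prefix."""
    for column_name in column_names:
        for prefix, (data_type, discreteness) in KNOWN_CUSTOM_OUTPUT_METRICS.items():
            if column_name.startswith(prefix):
                yield column_name, data_type, discreteness


print(list(columns_to_tag(["precision_k_3", "top_rank", "col1"])))
# -> [('precision_k_3', 'fractional', 'continuous'), ('top_rank', 'integral', 'continuous')]
```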


- [x] I have reviewed the [Guidelines for Contributing](CONTRIBUTING.md)
and the [Code of Conduct](CODE_OF_CONDUCT.md).

---------

Co-authored-by: felipe207 <felipe@whylabs.ai>
FelipeAdachi and felipe207 committed Apr 17, 2024
1 parent c1385e1 commit 8b2809e
Showing 2 changed files with 52 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/tests/api/writer/test_whylabs_integration.py
@@ -170,16 +170,16 @@ def test_performance_column():
     model_api = ModelsApi(writer._api_client)
     response: EntitySchema = model_api.get_entity_schema(ORG_ID, MODEL_ID)
     assert (
-        response["metrics"]["perf column"]["column"] == "col1"
-        and response["metrics"]["perf column"]["default_metric"] == "mean"
-        and response["metrics"]["perf column"]["label"] == "perf column"
+        response["metrics"]["perf_column"]["column"] == "col1"
+        and response["metrics"]["perf_column"]["default_metric"] == "mean"
+        and response["metrics"]["perf_column"]["label"] == "perf column"
     )

     # change it so we won't accidentally pass from previous state
     status, _ = writer.tag_custom_performance_column("col1", "perf column", "median")
     assert status
     response = model_api.get_entity_schema(ORG_ID, MODEL_ID)
-    assert response["metrics"]["perf column"]["default_metric"] == "median"
+    assert response["metrics"]["perf_column"]["default_metric"] == "median"


 @pytest.mark.load
49 changes: 48 additions & 1 deletion python/whylogs/api/writer/whylabs.py
@@ -81,7 +81,7 @@
 KNOWN_CUSTOM_PERFORMANCE_METRICS = {
     "mean_average_precision_k_": "mean",
     "accuracy_k_": "mean",
-    "mean_reciprocal_rank": "mean",
+    "reciprocal_rank": "mean",
     "precision_k_": "mean",
     "recall_k_": "mean",
     "top_rank": "mean",
@@ -90,6 +90,18 @@
     "sum_gain_k_": "mean",
 }

+KNOWN_CUSTOM_OUTPUT_METRICS = {
+    "mean_average_precision_k_": ("fractional", "continuous"),
+    "accuracy_k_": ("fractional", "continuous"),
+    "reciprocal_rank": ("fractional", "continuous"),
+    "precision_k_": ("fractional", "continuous"),
+    "recall_k_": ("fractional", "continuous"),
+    "top_rank": ("integral", "continuous"),
+    "average_precision_k_": ("fractional", "continuous"),
+    "norm_dis_cumul_gain_k_": ("fractional", "continuous"),
+    "sum_gain_k_": ("fractional", "continuous"),
+}
+

 def _check_whylabs_condition_count_uncompound() -> bool:
     global _WHYLABS_SKIP_CONFIG_READ
@@ -668,6 +680,20 @@ def _write_segmented_result_set(self, file: SegmentedResultSet, **kwargs: Any) -

         return and_status, "; ".join(messages)

+    def _tag_custom_output_metrics(self, view: Union[DatasetProfileView, SegmentedDatasetProfileView]):
+        if isinstance(view, DatasetProfileView):
+            column_names = view.get_columns().keys()
+            for column_name in column_names:
+                for perf_col in KNOWN_CUSTOM_OUTPUT_METRICS:
+                    if column_name.startswith(perf_col):
+                        data_type = KNOWN_CUSTOM_OUTPUT_METRICS[perf_col][0]
+                        discreteness = KNOWN_CUSTOM_OUTPUT_METRICS[perf_col][1]
+                        column_schema: ColumnSchema = ColumnSchema(
+                            classifier="output", data_type=data_type, discreteness=discreteness  # type: ignore
+                        )
+                        self._set_column_schema(column_name, column_schema=column_schema)
+        return
+
     def _tag_custom_perf_metrics(self, view: Union[DatasetProfileView, SegmentedDatasetProfileView]):
         if isinstance(view, DatasetProfileView):
             column_names = view.get_columns().keys()
@@ -743,6 +769,7 @@ def write(self, file: Writable, **kwargs: Any) -> Tuple[bool, str]:
         has_segments = isinstance(view, SegmentedDatasetProfileView)
         has_performance_metrics = view.model_performance_metrics
         self._tag_custom_perf_metrics(view)
+        self._tag_custom_output_metrics(view)
         if not has_segments and not isinstance(view, DatasetProfileView):
             raise ValueError(
                 "You must pass either a DatasetProfile or a DatasetProfileView in order to use this writer!"
@@ -1027,6 +1054,26 @@ def _column_schema_needs_update(self, column_schema: ColumnSchema, new_classific
             return True
         return existing_classification != new_classification

+    def _set_column_schema(self, column_name: str, column_schema: ColumnSchema):
+        model_api_instance = self._get_or_create_models_client()
+        try:
+            # TODO: remove when whylabs supports merge writes.
+            model_api_instance.put_entity_schema_column(  # type: ignore
+                self._org_id, self._dataset_id, column_name, column_schema=column_schema
+            )
+            return (
+                200,
+                f"{column_name} schema set to {column_schema.classifier} {column_schema.data_type} {column_schema.discreteness}",
+            )
+        except ForbiddenException as e:
+            logger.exception(
+                f"Failed to set column outputs {self._org_id}/{self._dataset_id} for column name: ("
+                f"{column_name}) "
+                f"{self.whylabs_api_endpoint}"
+                f" with API token ID: {self.key_id}"
+            )
+            raise e
+
     def _put_column_schema(self, column_name: str, value: str) -> Tuple[int, str]:
         model_api_instance = self._get_or_create_models_client()

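For context, a hedged end-to-end sketch of how the new tagging gets exercised (not part of this diff; the org/model IDs, API key, and the "precision_k_3" column name are placeholders):

```python
# Hedged usage sketch: because "precision_k_3" starts with a known ranking-metric
# prefix, write() now also tags it as an output column in the entity schema.
import pandas as pd
import whylogs as why
from whylogs.api.writer.whylabs import WhyLabsWriter

df = pd.DataFrame({"col1": [1, 2, 3], "precision_k_3": [0.5, 0.7, 0.9]})
profile_view = why.log(df).profile().view()

writer = WhyLabsWriter(org_id="org-0", api_key="<api-key>", dataset_id="model-0")
writer.write(profile_view)
```

In that flow, write() calls _tag_custom_perf_metrics and the new _tag_custom_output_metrics before uploading, so matching columns are marked as outputs via _set_column_schema even before the profile columns themselves exist in WhyLabs.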
