Skip to content

Commit

Permalink
Fix bug that prevents prefixes from being omitted in udfs
Browse files Browse the repository at this point in the history
The types for a udf schema allow the prefix to be None, but our code
doesn't work in that scenario. In addition, the annotation interface for
udfs doesn't even let you have a None prefix. This change fixes both of
those issues.
  • Loading branch information
naddeoa committed Feb 8, 2024
1 parent 20e62d6 commit f8ff9a7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
12 changes: 12 additions & 0 deletions python/tests/experimental/core/test_udf_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def f2(x: Union[Dict[str, List], pd.DataFrame]) -> Union[Dict[str, List], pd.Dat
return pd.DataFrame({"foo": x["xx1"], "bar": x["xx2"]})


@register_multioutput_udf(["xx1", "xx2"], no_prefix=True)
def no_prefix_udf(x: Union[Dict[str, List], pd.DataFrame]) -> Union[Dict[str, List], pd.DataFrame]:
    """Test UDF registered with ``no_prefix=True`` that emits "foo" and "bar".

    For a DataFrame (any non-dict input) the "xx1" and "xx2" columns are
    returned as a new DataFrame with columns "foo" and "bar". For a dict,
    only the first element of each input list is kept, wrapped in a
    single-item list.
    """
    # Non-dict input is treated as a DataFrame: pass whole columns through.
    if not isinstance(x, dict):
        return pd.DataFrame({"foo": x["xx1"], "bar": x["xx2"]})

    # Dict input: take the first value of each input column.
    first_xx1 = x["xx1"][0]
    first_xx2 = x["xx2"][0]
    return {"foo": [first_xx1], "bar": [first_xx2]}


def test_multioutput_udf_row() -> None:
schema = udf_schema()
row = {"xx1": 42, "xx2": 3.14}
Expand All @@ -79,6 +87,8 @@ def test_multioutput_udf_row() -> None:
assert results.get_column("f1.bar") is not None
assert results.get_column("blah.foo") is not None
assert results.get_column("blah.bar") is not None
assert results.get_column("foo") is not None
assert results.get_column("bar") is not None


def test_multioutput_udf_dataframe() -> None:
Expand All @@ -89,6 +99,8 @@ def test_multioutput_udf_dataframe() -> None:
assert results.get_column("f1.bar") is not None
assert results.get_column("blah.foo") is not None
assert results.get_column("blah.bar") is not None
assert results.get_column("foo") is not None
assert results.get_column("bar") is not None


@register_dataset_udf(["col1"], schema_name="unit-tests")
Expand Down
11 changes: 9 additions & 2 deletions python/whylogs/experimental/core/udf_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,13 @@ def _apply_udf_on_dataframe(
udf(Union[Dict[str, List], pd.DataFrame]) -> Union[Dict[str, List], pd.DataFrame]
"""

def add_prefix(col):
    """Return ``col`` qualified as ``"<prefix>.<col>"``, or unchanged if no prefix.

    ``prefix`` is read from the enclosing scope; a falsy prefix (``None`` or
    an empty string) leaves the column name untouched.
    """
    if not prefix:
        return col
    return prefix + "." + col

try:
# TODO: I think it's OKAY if udf returns a dictionary
udf_output = pd.DataFrame(udf(pandas))
udf_output = udf_output.rename(columns={old: prefix + "." + old for old in udf_output.keys()}) # type: ignore
udf_output = udf_output.rename(columns={old: add_prefix(old) for old in udf_output.keys()})
for new_col in udf_output.keys():
new_df[new_col] = udf_output[new_col]
except Exception as e: # noqa
Expand Down Expand Up @@ -271,6 +274,7 @@ def register_multioutput_udf(
prefix: Optional[str] = None,
namespace: Optional[str] = None,
schema_name: str = "",
no_prefix: bool = False,
) -> Callable[[Any], Any]:
"""
Decorator to easily configure UDFs for your data set. Decorate your UDF
Expand All @@ -294,7 +298,10 @@ def decorator_register(func):
global _multicolumn_udfs
name = udf_name or func.__name__
name = f"{namespace}.{name}" if namespace else name
output_prefix = prefix if prefix else name
if no_prefix:
output_prefix = None
else:
output_prefix = prefix if prefix else name
_multicolumn_udfs[schema_name].append(UdfSpec(col_names, prefix=output_prefix, udf=func, name=name))
return func

Expand Down

0 comments on commit f8ff9a7

Please sign in to comment.