Skip to content

Commit

Permalink
Support numeric, boolean, and string keyword arguments to class metho…
Browse files Browse the repository at this point in the history
…ds during CPU dispatching (#5236)

This PR:
- Updates the CPU dispatching logic to support keyword arguments that are numeric, boolean, or string values (i.e., values that can't be coerced to cuml arrays). Passing sequences as keyword arguments is not supported, as I believe this use case isn't necessary.
- Makes a minor update to the tests to exercise this dispatching in principle. We may want to expand the testing of keyword arguments in general, but since that likely warrants a broader discussion and test expansion, I thought it was out of scope for this small PR.


Closes #5218

This is a replacement for #5223

Authors:
  - Nick Becker (https://github.com/beckernick)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5236
  • Loading branch information
beckernick committed Apr 3, 2023
1 parent f116967 commit 4c6b880
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
13 changes: 11 additions & 2 deletions python/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import inspect
import numbers
from importlib import import_module
from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
Expand All @@ -36,7 +37,8 @@ from cuml.internals.device_type import DeviceType
from cuml.internals.input_utils import (
determine_array_type,
input_to_cuml_array,
input_to_host_array
input_to_host_array,
is_array_like
)
from cuml.internals.memory_utils import determine_array_memtype
from cuml.internals.mem_type import MemoryType
Expand Down Expand Up @@ -626,7 +628,14 @@ class UniversalBase(Base):
# put all the kwargs on host
new_kwargs = dict()
for kw, arg in kwargs.items():
new_kwargs[kw] = input_to_host_array(arg)[0]
# if array-like, ensure array-like is on the host
if is_array_like(arg):
new_kwargs[kw] = input_to_host_array(arg)[0]
# if Real or string, pass as is
elif isinstance(arg, (numbers.Real, str)):
new_kwargs[kw] = arg
else:
raise ValueError(f"Unable to process argument {kw}")
return new_args, new_kwargs

def dispatch_func(self, func_name, gpu_func, *args, **kwargs):
Expand Down
15 changes: 11 additions & 4 deletions python/cuml/tests/test_device_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def tsvd_test_data(request):
"input_type": ["numpy", "dataframe", "cupy", "cudf", "numba"],
"metric": ["euclidean", "cosine"],
"n_neighbors": [3, 8],
"return_distance": [True],
}
)
)
Expand All @@ -519,6 +520,7 @@ def nn_test_data(request):
"metric": request.param["metric"],
"n_neighbors": request.param["n_neighbors"],
}
infer_func_kwargs = {"return_distance": request.param["return_distance"]}

sk_model = skNearestNeighbors(**kwargs)
sk_model.fit(X_train_blob)
Expand All @@ -529,6 +531,7 @@ def nn_test_data(request):
"cuEstimator": NearestNeighbors,
"kwargs": kwargs,
"infer_func": "kneighbors",
"infer_func_kwargs": infer_func_kwargs,
"assert_func": check_nn,
"X_train": to_output_type(X_train_blob, input_type),
"X_test": to_output_type(X_test_blob, input_type),
Expand Down Expand Up @@ -563,7 +566,8 @@ def test_train_cpu_infer_cpu(test_data):
else:
model.fit(test_data["X_train"])
infer_func = getattr(model, test_data["infer_func"])
cuml_output = infer_func(test_data["X_test"])
infer_func_kwargs = test_data.get("infer_func_kwargs", {})
cuml_output = infer_func(test_data["X_test"], **infer_func_kwargs)

assert_func = test_data["assert_func"]
assert_func(cuml_output, test_data)
Expand All @@ -582,7 +586,8 @@ def test_train_gpu_infer_cpu(test_data):
model.fit(test_data["X_train"])
with using_device_type("cpu"):
infer_func = getattr(model, test_data["infer_func"])
cuml_output = infer_func(test_data["X_test"])
infer_func_kwargs = test_data.get("infer_func_kwargs", {})
cuml_output = infer_func(test_data["X_test"], **infer_func_kwargs)

assert_func = test_data["assert_func"]
assert_func(cuml_output, test_data)
Expand All @@ -598,7 +603,8 @@ def test_train_cpu_infer_gpu(test_data):
model.fit(test_data["X_train"])
with using_device_type("gpu"):
infer_func = getattr(model, test_data["infer_func"])
cuml_output = infer_func(test_data["X_test"])
infer_func_kwargs = test_data.get("infer_func_kwargs", {})
cuml_output = infer_func(test_data["X_test"], **infer_func_kwargs)

assert_func = test_data["assert_func"]
assert_func(cuml_output, test_data)
Expand All @@ -613,7 +619,8 @@ def test_train_gpu_infer_gpu(test_data):
else:
model.fit(test_data["X_train"])
infer_func = getattr(model, test_data["infer_func"])
cuml_output = infer_func(test_data["X_test"])
infer_func_kwargs = test_data.get("infer_func_kwargs", {})
cuml_output = infer_func(test_data["X_test"], **infer_func_kwargs)

assert_func = test_data["assert_func"]
assert_func(cuml_output, test_data)
Expand Down

0 comments on commit 4c6b880

Please sign in to comment.