Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [python-package] ensure predict() always returns an array #6348

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove special case from dask test
  • Loading branch information
jameslamb committed Mar 1, 2024
commit 5a45947cee0f7f8e325f0303f9f8186634b50da6
29 changes: 0 additions & 29 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
@@ -343,35 +343,6 @@ def test_classifier_pred_contrib(output, task, cluster):
else:
expected_num_cols = (num_features + 1) * num_classes

# in the special case of multi-class classification using scipy sparse matrices,
# the output of `.predict(..., pred_contrib=True)` is a list of sparse matrices (one per class)
#
# since that case is so different from all other cases, check the relevant things here
# and then return early
if output.startswith("scipy") and task == "multiclass-classification":
if output == "scipy_csr_matrix":
expected_type = csr_matrix
elif output == "scipy_csc_matrix":
expected_type = csc_matrix
else:
raise ValueError(f"Unrecognized output type: {output}")
assert isinstance(preds_with_contrib, list)
assert all(isinstance(arr, da.Array) for arr in preds_with_contrib)
assert all(isinstance(arr._meta, expected_type) for arr in preds_with_contrib)
assert len(preds_with_contrib) == num_classes
assert len(preds_with_contrib) == len(local_preds_with_contrib)
for i in range(num_classes):
computed_preds = preds_with_contrib[i].compute()
assert isinstance(computed_preds, expected_type)
assert computed_preds.shape[1] == num_classes
assert computed_preds.shape == local_preds_with_contrib[i].shape
assert len(np.unique(computed_preds[:, -1])) == 1
# raw scores will probably be different, but at least check that all predicted classes are the same
pred_classes = np.argmax(computed_preds.toarray(), axis=1)
local_pred_classes = np.argmax(local_preds_with_contrib[i].toarray(), axis=1)
np.testing.assert_array_equal(pred_classes, local_pred_classes)
return

preds_with_contrib = preds_with_contrib.compute()
if output.startswith("scipy"):
preds_with_contrib = preds_with_contrib.toarray()
Loading
Oops, something went wrong.