ENH Array API support for f1_score and multilabel_confusion_matrix #27369

Status: Open. Wants to merge 26 commits into `main` (changes shown from 17 commits).

Commits:
- 9bcbc4c ENH Array API support for f1_score (OmarManzoor, Sep 14, 2023)
- 87b65b9 Merge branch 'main' into f1_array_api (OmarManzoor, May 17, 2024)
- 17457b9 Merge branch 'main' into f1_array_api (OmarManzoor, May 17, 2024)
- c80617f Add array api support for f1_score (OmarManzoor, May 20, 2024)
- 649ce17 Add changelog (OmarManzoor, May 20, 2024)
- 8f0db56 Merge branch 'main' into f1_array_api (OmarManzoor, May 20, 2024)
- 6a02fcd Fix sample weights in _bincount (OmarManzoor, May 20, 2024)
- c150c9c Add some fixes (OmarManzoor, May 20, 2024)
- a01f2d7 Correct and add tests for nanmean (OmarManzoor, May 20, 2024)
- aa2f521 Add options for testing with various average values (OmarManzoor, May 20, 2024)
- 75c7d5a Use reshape when creating arrays in micro average (OmarManzoor, May 20, 2024)
- 8b21b51 Add LabelEncoder and f1_score in array_api.rst (OmarManzoor, May 27, 2024)
- bc8c2df Merge branch 'main' into f1_array_api (OmarManzoor, May 27, 2024)
- ef33cf6 Merge branch 'main' into f1_array_api (ogrisel, Jun 5, 2024)
- 91ab0d5 Merge branch 'main' into f1_array_api (OmarManzoor, Jun 6, 2024)
- 696e65b Update: PR suggestions (OmarManzoor, Jun 6, 2024)
- d0b647b Use xp.reshape with (1,) (OmarManzoor, Jun 6, 2024)
- 842e269 Simplify count in _nanmean (OmarManzoor, Jun 6, 2024)
- 6e9596e Merge branch 'main' into f1_array_api (OmarManzoor, Jun 6, 2024)
- 5cd9a11 Merge branch 'main' into f1_array_api (OmarManzoor, Jun 7, 2024)
- 2c3cc32 Merge branch 'main' into f1_array_api (OmarManzoor, Jun 14, 2024)
- 5c73766 Merge branch 'main' into f1_array_api (OmarManzoor, Jun 25, 2024)
- 0df1e0f Add multilabel confusion metrics as it seems to work (OmarManzoor, Jun 25, 2024)
- 3d42289 Merge branch 'main' into f1_array_api (OmarManzoor, Jul 2, 2024)
- 78e4f31 Handle multi-label case (OmarManzoor, Jul 2, 2024)
- 74ccf6a Fix commented tests (OmarManzoor, Jul 2, 2024)
Files changed:
2 changes: 2 additions & 0 deletions doc/modules/array_api.rst
@@ -94,6 +94,7 @@ Estimators
 - :class:`linear_model.Ridge` (with `solver="svd"`)
 - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
 - :class:`preprocessing.KernelCenterer`
+- :class:`preprocessing.LabelEncoder`
 - :class:`preprocessing.MaxAbsScaler`
 - :class:`preprocessing.MinMaxScaler`
 - :class:`preprocessing.Normalizer`
@@ -102,6 +103,7 @@ Metrics
 -------
 
 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.f1_score`
 - :func:`sklearn.metrics.mean_absolute_error`
 - :func:`sklearn.metrics.mean_tweedie_deviance`
 - :func:`sklearn.metrics.pairwise.cosine_similarity`
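With this documentation entry in place, `f1_score` accepts array API inputs when dispatch is enabled. A minimal sketch of how the new support could be exercised (assumes this branch is installed along with PyTorch and array-api-compat; the data values are illustrative only):

```python
import torch

from sklearn import config_context
from sklearn.metrics import f1_score

y_true = torch.asarray([0, 1, 2, 3])
y_pred = torch.asarray([0, 1, 0, 2])

# With dispatch enabled, the metric runs on the torch tensors directly
# instead of converting them to NumPy arrays first.
with config_context(array_api_dispatch=True):
    macro_f1 = f1_score(y_true, y_pred, average="macro")

print(macro_f1)
```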
2 changes: 1 addition & 1 deletion doc/whats_new/v1.6.rst
@@ -37,7 +37,7 @@ See :ref:`array_api` for more details.
   :pr:`28106` by :user:`Thomas Li <lithomas1>`;
 - :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;
 - :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.
-
+- :func:`sklearn.metrics.f1_score` :pr:`27369` by :user:`Omar Salman <OmarManzoor>`.
 
 **Classes:**
 
87 changes: 51 additions & 36 deletions sklearn/metrics/_classification.py
@@ -40,7 +40,12 @@
 )
 from ..utils._array_api import (
     _average,
+    _bincount,
+    _find_matching_floating_dtype,
+    _searchsorted,
+    _setdiff1d,
     _union1d,
+    device,
     get_namespace,
 )
 from ..utils._param_validation import (
@@ -514,9 +519,11 @@ def multilabel_confusion_matrix(
            [[2, 1],
             [1, 2]]])
     """
+    xp, _ = get_namespace(y_true, y_pred)
+    device_ = device(y_true, y_pred)
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
     if sample_weight is not None:
-        sample_weight = column_or_1d(sample_weight)
+        sample_weight = xp.asarray(column_or_1d(sample_weight), device=device_)
     check_consistent_length(y_true, y_pred, sample_weight)
 
     if y_type not in ("binary", "multiclass", "multilabel-indicator"):
@@ -527,9 +534,11 @@
         labels = present_labels
         n_labels = None
     else:
-        n_labels = len(labels)
-        labels = np.hstack(
-            [labels, np.setdiff1d(present_labels, labels, assume_unique=True)]
+        labels = xp.asarray(labels, device=device_)
+        n_labels = labels.shape[0]
+        labels = xp.concat(
+            [labels, _setdiff1d(present_labels, labels, assume_unique=True, xp=xp)],
+            axis=-1,
         )
 
     if y_true.ndim == 1:
@@ -549,29 +558,33 @@
         tp = y_true == y_pred
         tp_bins = y_true[tp]
         if sample_weight is not None:
-            tp_bins_weights = np.asarray(sample_weight)[tp]
+            tp_bins_weights = sample_weight[tp]
         else:
             tp_bins_weights = None
 
-        if len(tp_bins):
-            tp_sum = np.bincount(
-                tp_bins, weights=tp_bins_weights, minlength=len(labels)
+        if tp_bins.shape[0]:
+            tp_sum = _bincount(
+                xp, tp_bins, weights=tp_bins_weights, minlength=labels.shape[0]
             )
         else:
             # Pathological case
-            true_sum = pred_sum = tp_sum = np.zeros(len(labels))
-        if len(y_pred):
-            pred_sum = np.bincount(y_pred, weights=sample_weight, minlength=len(labels))
-        if len(y_true):
-            true_sum = np.bincount(y_true, weights=sample_weight, minlength=len(labels))
+            true_sum = pred_sum = tp_sum = xp.zeros(labels.shape[0])
+        if y_pred.shape[0]:
+            pred_sum = _bincount(
+                xp, y_pred, weights=sample_weight, minlength=labels.shape[0]
+            )
+        if y_true.shape[0]:
+            true_sum = _bincount(
+                xp, y_true, weights=sample_weight, minlength=labels.shape[0]
+            )
 
         # Retain only selected labels
-        indices = np.searchsorted(sorted_labels, labels[:n_labels])
-        tp_sum = tp_sum[indices]
-        true_sum = true_sum[indices]
-        pred_sum = pred_sum[indices]
+        indices = _searchsorted(xp, sorted_labels, labels[:n_labels])
+        tp_sum = xp.take(tp_sum, indices, axis=0)
+        true_sum = xp.take(true_sum, indices, axis=0)
+        pred_sum = xp.take(pred_sum, indices, axis=0)
 
-    else:
+    else:  # y_true is a 2D sparse matrix of one-hot multi-label indicators
         sum_axis = 1 if samplewise else 0
 
         # All labels are index integers for multilabel.
@@ -607,19 +620,18 @@
     tp = tp_sum
 
     if sample_weight is not None and samplewise:
-        sample_weight = np.array(sample_weight)
-        tp = np.array(tp)
-        fp = np.array(fp)
-        fn = np.array(fn)
+        tp = xp.asarray(tp)
+        fp = xp.asarray(fp)
+        fn = xp.asarray(fn)
         tn = sample_weight * y_true.shape[1] - tp - fp - fn
     elif sample_weight is not None:
-        tn = sum(sample_weight) - tp - fp - fn
+        tn = xp.sum(sample_weight) - tp - fp - fn
     elif samplewise:
         tn = y_true.shape[1] - tp - fp - fn
     else:
         tn = y_true.shape[0] - tp - fp - fn
 
-    return np.array([tn, fp, fn, tp]).T.reshape(-1, 2, 2)
+    return xp.reshape(xp.stack([tn, fp, fn, tp]).T, (-1, 2, 2))
 
 
 @validate_params(
@@ -1494,12 +1506,14 @@ def _prf_divide(
     The metric, modifier and average arguments are used only for determining
     an appropriate warning.
     """
-    mask = denominator == 0.0
-    denominator = denominator.copy()
+    xp, _ = get_namespace(numerator, denominator)
+    dtype_float = _find_matching_floating_dtype(numerator, denominator, xp=xp)
+    mask = denominator == 0
+    denominator = xp.asarray(denominator, copy=True, dtype=dtype_float)
     denominator[mask] = 1  # avoid infs/nans
-    result = numerator / denominator
+    result = xp.asarray(numerator, dtype=dtype_float) / denominator
 
-    if not np.any(mask):
+    if not xp.any(mask):
         return result
 
     # set those with 0 denominator to `zero_division`, and 0 when "warn"
Expand Down Expand Up @@ -1547,7 +1561,7 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
# Convert to Python primitive type to avoid NumPy type / Python str
# comparison. See https://github.com/numpy/numpy/issues/6784
present_labels = unique_labels(y_true, y_pred).tolist()
present_labels = list(unique_labels(y_true, y_pred))
if average == "binary":
if y_type == "binary":
if pos_label not in present_labels:
@@ -1787,10 +1801,11 @@ def precision_recall_fscore_support(
     pred_sum = tp_sum + MCM[:, 0, 1]
     true_sum = tp_sum + MCM[:, 1, 0]
 
+    xp, _ = get_namespace(y_true, y_pred)
     if average == "micro":
-        tp_sum = np.array([tp_sum.sum()])
-        pred_sum = np.array([pred_sum.sum()])
-        true_sum = np.array([true_sum.sum()])
+        tp_sum = xp.reshape(xp.sum(tp_sum), shape=(1,))
+        pred_sum = xp.reshape(xp.sum(pred_sum), shape=(1,))
+        true_sum = xp.reshape(xp.sum(true_sum), shape=(1,))
 
     # Finally, we have all our sufficient statistics. Divide! #
     beta2 = beta**2
@@ -1833,10 +1848,10 @@ def precision_recall_fscore_support(
         weights = None
 
     if average is not None:
-        assert average != "binary" or len(precision) == 1
-        precision = _nanaverage(precision, weights=weights)
-        recall = _nanaverage(recall, weights=weights)
-        f_score = _nanaverage(f_score, weights=weights)
+        assert average != "binary" or precision.shape[0] == 1
+        precision = float(_nanaverage(precision, weights=weights))
+        recall = float(_nanaverage(recall, weights=weights))
+        f_score = float(_nanaverage(f_score, weights=weights))
         true_sum = None  # return no support
 
     return precision, recall, f_score, true_sum
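As context for the reshape at the end of `multilabel_confusion_matrix`: the function returns one 2x2 matrix per label, laid out as `[[tn, fp], [fn, tp]]`, which is exactly how `precision_recall_fscore_support` slices it above (`MCM[:, 1, 1]` for tp, `MCM[:, 0, 1]` for fp, `MCM[:, 1, 0]` for fn). A small NumPy-only illustration of that contract (the data is made up):

```python
import numpy as np

from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 1, 1, 2])

# One 2x2 confusion matrix per label: [[tn, fp], [fn, tp]].
mcm = multilabel_confusion_matrix(y_true, y_pred)
print(mcm.shape)  # (3, 2, 2): one slice per class label

tp = mcm[:, 1, 1]
fp = mcm[:, 0, 1]
fn = mcm[:, 1, 0]
print(tp, fp, fn)
```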
50 changes: 31 additions & 19 deletions sklearn/metrics/tests/test_common.py
@@ -1803,27 +1803,35 @@ def check_array_api_multiclass_classification_metric(
     y_true_np = np.array([0, 1, 2, 3])
     y_pred_np = np.array([0, 1, 0, 2])
 
-    check_array_api_metric(
-        metric,
-        array_namespace,
-        device,
-        dtype_name,
-        a_np=y_true_np,
-        b_np=y_pred_np,
-        sample_weight=None,
-    )
+    metric_kwargs_combinations = [{}]
+    if "average" in signature(metric).parameters:
+        average_options = ("micro", "macro", "weighted")
+        metric_kwargs_combinations = [{"average": option} for option in average_options]
 
-    sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name)
+    for metric_kwargs in metric_kwargs_combinations:
+        check_array_api_metric(
+            metric,
+            array_namespace,
+            device,
+            dtype_name,
+            a_np=y_true_np,
+            b_np=y_pred_np,
+            sample_weight=None,
+            **metric_kwargs,
+        )
 
-    check_array_api_metric(
-        metric,
-        array_namespace,
-        device,
-        dtype_name,
-        a_np=y_true_np,
-        b_np=y_pred_np,
-        sample_weight=sample_weight,
-    )
+        sample_weight = np.array([0.0, 0.1, 2.0, 1.0], dtype=dtype_name)
+
+        check_array_api_metric(
+            metric,
+            array_namespace,
+            device,
+            dtype_name,
+            a_np=y_true_np,
+            b_np=y_pred_np,
+            sample_weight=sample_weight,
+            **metric_kwargs,
+        )
 
 
 def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
@@ -1924,6 +1932,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name):
         check_array_api_binary_classification_metric,
         check_array_api_multiclass_classification_metric,
     ],
+    f1_score: [
+        check_array_api_binary_classification_metric,
+        check_array_api_multiclass_classification_metric,
+    ],
     zero_one_loss: [
         check_array_api_binary_classification_metric,
         check_array_api_multiclass_classification_metric,
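The `average` sweep above keys off the metric's signature, so only metrics that actually expose an `average` parameter get the extra combinations. The same check in isolation:

```python
from inspect import signature

from sklearn.metrics import accuracy_score, f1_score

# f1_score exposes `average`, so it is tested with micro/macro/weighted;
# accuracy_score does not, so it keeps the single default combination.
print("average" in signature(f1_score).parameters)        # True
print("average" in signature(accuracy_score).parameters)  # False
```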
35 changes: 35 additions & 0 deletions sklearn/utils/_array_api.py
@@ -735,6 +735,26 @@ def _nanmax(X, axis=None, xp=None):
     return X
 
 
+def _nanmean(X, axis=None, xp=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _ = get_namespace(X, xp=xp)
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nanmean(X, axis=axis))
+    else:
+        mask = xp.isnan(X)
+        total = xp.sum(xp.where(mask, xp.asarray(0.0, device=device(X)), X), axis=axis)
+        count = xp.sum(
+            xp.where(
+                mask,
+                xp.asarray(0.0, device=device(X)),
+                xp.asarray(1.0, device=device(X)),
+            ),
+            axis=axis,
+        )
+        return total / count
+
+
 def _asarray_with_order(
     array, dtype=None, order=None, copy=None, *, xp=None, device=None
 ):
@@ -967,3 +987,18 @@ def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
         return ret[: ar1.shape[0]]
     else:
         return xp.take(ret, rev_idx, axis=0)
+
+
+def _bincount(xp, array, weights=None, minlength=None):
+    # TODO: update if bincount is ever adopted in a future version of the standard:
+    # https://github.com/data-apis/array-api/issues/812
+    if hasattr(xp, "bincount"):
+        return xp.bincount(array, weights=weights, minlength=minlength)
+
+    array_np = _convert_to_numpy(array, xp=xp)
+    if weights is not None:
+        weights_np = _convert_to_numpy(weights, xp=xp)
+    else:
+        weights_np = None
+    bin_out = numpy.bincount(array_np, weights=weights_np, minlength=minlength)
+    return xp.asarray(bin_out, device=device(array))
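The `_nanmean` fallback uses only operations the array API standard guarantees: mask the NaNs with `where`, sum, and divide by the count of non-NaN entries. The same arithmetic written against plain NumPy, as a sketch of the technique rather than the library helper itself (`nanmean_via_where` is an illustrative name):

```python
import numpy as np

def nanmean_via_where(X, axis=None):
    # Replace NaNs with 0.0 so they do not contribute to the sum, then
    # divide by how many entries were not NaN along the same axis.
    mask = np.isnan(X)
    total = np.sum(np.where(mask, 0.0, X), axis=axis)
    count = np.sum(np.where(mask, 0.0, 1.0), axis=axis)
    return total / count

X = np.array([[1.0, np.nan], [3.0, 4.0]])
print(nanmean_via_where(X, axis=0))  # [2. 4.]
print(np.nanmean(X, axis=0))         # matches: [2. 4.]
```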
22 changes: 11 additions & 11 deletions sklearn/utils/extmath.py
@@ -19,7 +19,7 @@
 
 from ..utils._param_validation import Interval, StrOptions, validate_params
 from ..utils.deprecation import deprecated
-from ._array_api import _is_numpy_namespace, device, get_namespace
+from ._array_api import _average, _is_numpy_namespace, _nanmean, device, get_namespace
 from .sparsefuncs_fast import csr_row_norms
 from .validation import check_array, check_random_state
 
@@ -1264,24 +1264,24 @@ def _nanaverage(a, weights=None):
     that :func:`np.nan` values are ignored from the average and weights can
     be passed. Note that when possible, we delegate to the prime methods.
     """
+    xp, _ = get_namespace(a)
+    if a.shape[0] == 0:
+        return xp.nan
 
-    if len(a) == 0:
-        return np.nan
-
-    mask = np.isnan(a)
-    if mask.all():
-        return np.nan
+    mask = xp.isnan(a)
+    if xp.all(mask):
+        return xp.nan
 
     if weights is None:
-        return np.nanmean(a)
+        return _nanmean(a)
 
-    weights = np.asarray(weights)
+    weights = xp.asarray(weights)
     a, weights = a[~mask], weights[~mask]
     try:
-        return np.average(a, weights=weights)
+        return _average(a, weights=weights)
     except ZeroDivisionError:
         # this is when all weights are zero, then ignore them
-        return np.average(a)
+        return _average(a)
 
 
 def safe_sqr(X, *, copy=True):
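The contract `_nanaverage` implements: NaN entries and their weights are dropped, an empty or all-NaN input yields NaN, and if the surviving weights sum to zero the weights are ignored. A NumPy-only sketch of that behaviour (`nanaverage_demo` is an illustrative name, not a scikit-learn function):

```python
import numpy as np

def nanaverage_demo(a, weights=None):
    # Empty or all-NaN input: nothing to average.
    if a.size == 0:
        return np.nan
    mask = np.isnan(a)
    if mask.all():
        return np.nan
    if weights is None:
        return np.nanmean(a)
    # Drop NaNs together with their weights before averaging.
    a, weights = a[~mask], np.asarray(weights)[~mask]
    try:
        return np.average(a, weights=weights)
    except ZeroDivisionError:
        # All surviving weights are zero: fall back to the plain mean.
        return np.average(a)

print(nanaverage_demo(np.array([1.0, np.nan, 3.0]), weights=[1.0, 5.0, 1.0]))  # 2.0
```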
14 changes: 14 additions & 0 deletions sklearn/utils/tests/test_array_api.py
@@ -17,6 +17,7 @@
     _is_numpy_namespace,
     _isin,
     _nanmax,
+    _nanmean,
     _nanmin,
     _NumPyAPIWrapper,
     _ravel,
@@ -340,6 +341,19 @@ def __init__(self, device_name):
             partial(_nanmax, axis=1),
             [3.0, numpy.nan, 6.0],
         ),
+        ([1, 2, numpy.nan], _nanmean, 1.5),
+        ([1, -2, -numpy.nan], _nanmean, -0.5),
+        ([-numpy.inf, -numpy.inf], _nanmean, -numpy.inf),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmean, axis=0),
+            [2.5, 3.5, 4.5],
+        ),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmean, axis=1),
+            [2.0, numpy.nan, 5.0],
+        ),
     ],
 )
 def test_nan_reductions(library, X, reduction, expected):
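The expected values in these parametrized cases line up with `numpy.nanmean` on the same inputs, which is what the NumPy fast path of `_nanmean` delegates to. A quick cross-check outside the test suite:

```python
import numpy as np

# NaNs are excluded before averaging, so only the finite entries count.
assert np.nanmean([1, 2, np.nan]) == 1.5
assert np.nanmean([1, -2, np.nan]) == -0.5

X = np.array([[1, 2, 3], [np.nan, np.nan, np.nan], [4, 5, 6.0]])
print(np.nanmean(X, axis=0))  # [2.5 3.5 4.5]
```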