From b658e4805be9515ba9059b0d18cdddbe6d24d19c Mon Sep 17 00:00:00 2001 From: ROMEEZHOU Date: Thu, 6 Apr 2023 00:05:21 +0800 Subject: [PATCH] MAINT Parameters validation for sklearn.metrics.mean_poisson_deviance --- sklearn/metrics/_regression.py | 7 +++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 5d84ea7b89966..377c3f8c467cf 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -1172,6 +1172,13 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0): ) +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + } +) def mean_poisson_deviance(y_true, y_pred, *, sample_weight=None): """Mean Poisson deviance regression loss. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 55c33295e1d3d..fc9745e1d94a2 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -199,6 +199,7 @@ def _check_function_param_validation( "sklearn.metrics.mean_absolute_percentage_error", "sklearn.metrics.mean_gamma_deviance", "sklearn.metrics.mean_pinball_loss", + "sklearn.metrics.mean_poisson_deviance", "sklearn.metrics.mean_squared_error", "sklearn.metrics.mean_squared_log_error", "sklearn.metrics.mean_tweedie_deviance",