diff --git a/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py b/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py index 382e6615..879b04b2 100644 --- a/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py +++ b/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py @@ -18,7 +18,7 @@ def compute_breakdown( metadata, prediction_column_name, minority_class_label, - classifier='xgboost', + classifier='XGBoost', fixed_recall_value=0.9, ): """Compute the score breakdown of the metric.""" diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py index 7459ee4a..836d978d 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py @@ -25,7 +25,6 @@ def test_end_to_end(self): metadata=metadata, prediction_column_name='gender', minority_class_label='F', - classifier='XGBoost', fixed_recall_value=0.8, ) diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index 66c40dee..60bf2a54 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -43,7 +43,6 @@ def test_end_to_end(self): metadata=metadata, prediction_column_name='gender', minority_class_label='F', - classifier='XGBoost', fixed_precision_value=0.8, ) diff --git a/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py index aad62e37..a137c3cd 100644 --- a/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py +++ b/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py @@ -32,8 +32,6 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata = {} prediction_column_name = 'prediction_column_name' minority_class_label = 'minority_class_label' - classifier = 'XGBoost' - fixed_recall_value = 0.8 # Run BinaryClassifierPrecisionEfficacy.compute_breakdown( @@ -43,8 +41,6 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata=metadata, prediction_column_name=prediction_column_name, minority_class_label=minority_class_label, - classifier=classifier, - fixed_recall_value=fixed_recall_value, ) # Assert @@ -55,8 +51,8 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata, prediction_column_name, minority_class_label, - classifier, - fixed_recall_value, + 'XGBoost', + 0.9, ) @patch('sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric.compute') diff --git a/tests/unit/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/unit/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index 350dc082..7e51a058 100644 --- a/tests/unit/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/unit/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -32,8 +32,6 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata = {} prediction_column_name = 'prediction_column_name' minority_class_label = 'minority_class_label' - classifier = 'XGBoost' - fixed_precision_value = 0.8 # Run BinaryClassifierRecallEfficacy.compute_breakdown( @@ -43,8 +41,6 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata=metadata, prediction_column_name=prediction_column_name, minority_class_label=minority_class_label, - classifier=classifier, - fixed_precision_value=fixed_precision_value, ) # Assert @@ -55,8 +51,8 @@ def test_compute_breakdown(self, mock_compute_breakdown): metadata, prediction_column_name, minority_class_label, - classifier, - fixed_precision_value, + 'XGBoost', + 0.9, ) @patch('sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric.compute')