diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 9e66c905..cd6e97ff 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -27,10 +27,15 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install libomp (macOS only)
+        if: matrix.os == 'macos-latest'
+        run: |
+          brew install libomp
+          echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install invoke .[pomegranate,test]
+          python -m pip install invoke .[pomegranate,xgboost,test]
       - name: Run integration tests
         run: invoke integration
 
diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml
index a664cb93..bd8c1863 100644
--- a/.github/workflows/minimum.yml
+++ b/.github/workflows/minimum.yml
@@ -27,9 +27,14 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install libomp (macOS only)
+        if: matrix.os == 'macos-latest'
+        run: |
+          brew install libomp
+          echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install invoke .[test]
+          python -m pip install invoke .[test,xgboost]
       - name: Test with minimum versions
         run: invoke minimum
diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml
index d0873c59..ed1aa264 100644
--- a/.github/workflows/unit.yml
+++ b/.github/workflows/unit.yml
@@ -27,10 +27,15 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install libomp (macOS only)
+        if: matrix.os == 'macos-latest'
+        run: |
+          brew install libomp
+          echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install invoke .[test]
+          python -m pip install invoke .[test,xgboost]
       - name: Run unit tests
         run: invoke unit
 
diff --git a/pyproject.toml b/pyproject.toml
index 542d7fef..114d8ab4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,6 +55,7 @@ torch = [
     "torch>=2.2.0;python_version>='3.12'",
 ]
 pomegranate = ['pomegranate>=0.15,<1.0']
+xgboost = ['xgboost>=2.1.3']
 test = [
     'sdmetrics[torch]',
     'pytest>=6.2.5,<7',
@@ -67,7 +68,7 @@ test = [
     'pytest-runner>=2.11.1',
 ]
 dev = [
-    'sdmetrics[test, torch]',
+    'sdmetrics[test, xgboost, torch]',
 
     # general
     'build>=1.0.0,<2',
diff --git a/sdmetrics/single_table/data_augmentation/__init__.py b/sdmetrics/single_table/data_augmentation/__init__.py
new file mode 100644
index 00000000..1b65bbbc
--- /dev/null
+++ b/sdmetrics/single_table/data_augmentation/__init__.py
@@ -0,0 +1,7 @@
+"""Data Augmentation Metric for single table datasets."""
+
+from sdmetrics.single_table.data_augmentation.binary_classifier_precision_efficacy import (
+    BinaryClassifierPrecisionEfficacy,
+)
+
+__all__ = ['BinaryClassifierPrecisionEfficacy']
diff --git a/sdmetrics/single_table/data_augmentation/base.py b/sdmetrics/single_table/data_augmentation/base.py
new file mode 100644
index 00000000..9a6d6eaa
--- /dev/null
+++ b/sdmetrics/single_table/data_augmentation/base.py
@@ -0,0 +1,288 @@
+"""Base class for Efficacy metrics for single table datasets."""
+
+from copy import deepcopy
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import confusion_matrix, precision_recall_curve, precision_score, recall_score
+from xgboost import XGBClassifier
+
+from sdmetrics.goal import Goal
+from sdmetrics.single_table.base import SingleTableMetric
+from sdmetrics.single_table.data_augmentation.utils import _validate_inputs
+
+METRIC_NAME_TO_METHOD = {'recall': recall_score, 'precision': precision_score}
+
+
+class ClassifierTrainer:
+    """Class to train a classifier model."""
+
+    def __init__(
+        self,
+        prediction_column_name,
+        minority_class_label,
+        classifier,
+        fixed_value,
+        metric_name,
+    ):
+        self.prediction_column_name = prediction_column_name
+        self.minority_class_label = minority_class_label
+        self.fixed_value = fixed_value
+        self.metric_name = metric_name
+        self._classifier_name = classifier
+        self._classifier = XGBClassifier(enable_categorical=True)
+        self._metric_to_fix = 'recall' if metric_name == 'precision' else 'precision'
+        self._metric_method = METRIC_NAME_TO_METHOD[self._metric_to_fix]
+
+    def train_model(self, train_data):
+        """Train the classifier model."""
+        train_target = train_data.pop(self.prediction_column_name)
+        self._classifier.fit(train_data, train_target)
+        self._best_threshold = self.get_best_threshold(train_data, train_target)
+        probabilities = self._classifier.predict_proba(train_data)[:, 1]
+        predictions = (probabilities >= self._best_threshold).astype(int)
+
+        return self._metric_method(train_target, predictions)
+
+    def get_best_threshold(self, train_data, train_target):
+        """Find the best threshold for the classifier model."""
+        target_probabilities = self._classifier.predict_proba(train_data)[:, 1]
+        precision, recall, thresholds = precision_recall_curve(train_target, target_probabilities)
+        metric_map = {'precision': precision, 'recall': recall}
+        metric = metric_map[self._metric_to_fix]
+        valid_idx = np.where(metric >= self.fixed_value)[0]
+        if valid_idx.size:
+            best_idx = valid_idx[np.argmin(metric[valid_idx] - self.fixed_value)]
+            return thresholds[best_idx] if best_idx < len(thresholds) else 1.0
+
+        return 1.0
+
+    def compute_validation_scores(self, real_validation_data):
+        """Compute the validation scores."""
+        real_validation_target = real_validation_data.pop(self.prediction_column_name)
+        predictions = self._classifier.predict_proba(real_validation_data)[:, 1]
+        predictions = (predictions >= self._best_threshold).astype(int)
+        recall = recall_score(real_validation_target, predictions)
+        precision = precision_score(real_validation_target, predictions)
+        conf_matrix = confusion_matrix(real_validation_target, predictions)
+        prediction_counts_validation = {
+            'true_positive': int(conf_matrix[1, 1]),
+            'false_positive': int(conf_matrix[0, 1]),
+            'true_negative': int(conf_matrix[0, 0]),
+            'false_negative': int(conf_matrix[1, 0]),
+        }
+
+        return recall, precision, prediction_counts_validation
+
+    def get_scores(self, training_table, validation_table):
+        """Get the scores of the metric."""
+        training_table = deepcopy(training_table)
+        validation_table = deepcopy(validation_table)
+        training_score = self.train_model(training_table)
+        recall, precision, prediction_counts_validation = self.compute_validation_scores(
+            validation_table
+        )
+        return {
+            f'{self._metric_to_fix}_score_training': training_score,
+            'recall_score_validation': recall,
+            'precision_score_validation': precision,
+            'prediction_counts_validation': prediction_counts_validation,
+        }
+
+
+class BaseDataAugmentationMetric(SingleTableMetric):
+    """Base class for Data Augmentation metrics for single table datasets."""
+
+    name = None
+    metric_name = None
+    goal = Goal.MAXIMIZE
+    min_value = 0.0
+    max_value = 1.0
+
+    @classmethod
+    def _fit(cls, data, metadata, prediction_column_name):
+        """Fit preprocessing parameters."""
+        discrete_columns = []
+        datetime_columns = []
+        for column, column_meta in metadata['columns'].items():
+            if (column_meta['sdtype'] in ['categorical', 'boolean']) and (
+                column != prediction_column_name
+            ):
+                discrete_columns.append(column)
+            elif column_meta['sdtype'] == 'datetime':
+                datetime_columns.append(column)
+
+        return discrete_columns, datetime_columns
+
+    @classmethod
+    def _transform(
+        cls,
+        tables,
+        discrete_columns,
+        datetime_columns,
+        prediction_column_name,
+        minority_class_label,
+    ):
+        """Transform by preprocessing the tables.
+
+        Args:
+            tables (dict[str, pandas.DataFrame]):
+                Dict containing `real_training_data`, `synthetic_data` and `real_validation_data`.
+        """
+        tables_result = {}
+        for table_name, table in tables.items():
+            table = table.copy()
+            table[discrete_columns] = table[discrete_columns].astype('category')
+            table[datetime_columns] = table[datetime_columns].apply(pd.to_numeric)
+            table[prediction_column_name] = (
+                table[prediction_column_name] == minority_class_label
+            ).astype(int)
+            tables_result[table_name] = table
+
+        return tables_result
+
+    @classmethod
+    def _fit_transform(
+        cls,
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+    ):
+        """Fit and transform the metric."""
+        discrete_columns, datetime_columns = cls._fit(
+            real_training_data, metadata, prediction_column_name
+        )
+        tables = {
+            'real_training_data': real_training_data,
+            'synthetic_data': synthetic_data,
+            'real_validation_data': real_validation_data,
+        }
+
+        return cls._transform(
+            tables,
+            discrete_columns,
+            datetime_columns,
+            prediction_column_name,
+            minority_class_label,
+        )
+
+    @classmethod
+    def compute_breakdown(
+        cls,
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+        classifier,
+        fixed_recall_value,
+    ):
+        """Compute the score breakdown of the metric."""
+        _validate_inputs(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+        )
+        preprocessed_tables = cls._fit_transform(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+        )
+        trainer = ClassifierTrainer(
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+            cls.metric_name,
+        )
+        metric_to_fix = 'recall' if cls.metric_name == 'precision' else 'precision'
+        result = {
+            'real_data_baseline': trainer.get_scores(
+                preprocessed_tables['real_training_data'],
+                preprocessed_tables['real_validation_data'],
+            ),
+            'augmented_data': trainer.get_scores(
+                pd.concat([
+                    preprocessed_tables['real_training_data'],
+                    preprocessed_tables['synthetic_data'],
+                ]).reset_index(drop=True),
+                preprocessed_tables['real_validation_data'],
+            ),
+            'parameters': {
+                'prediction_column_name': trainer.prediction_column_name,
+                'minority_class_label': trainer.minority_class_label,
+                'classifier': trainer._classifier_name,
+                f'fixed_{metric_to_fix}_value': trainer.fixed_value,
+            },
+        }
+        result['score'] = max(
+            0,
+            (
+                result['augmented_data'][f'{cls.metric_name}_score_validation']
+                - result['real_data_baseline'][f'{cls.metric_name}_score_validation']
+            ),
+        )
+        return result
+
+    @classmethod
+    def compute(
+        cls,
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+        classifier=None,
+        fixed_recall_value=0.9,
+    ):
+        """Compute the score of the metric.
+
+        Args:
+            real_training_data (pandas.DataFrame):
+                The real training data.
+            synthetic_data (pandas.DataFrame):
+                The synthetic data.
+            real_validation_data (pandas.DataFrame):
+                The real validation data.
+            metadata (dict):
+                The metadata dictionary describing the table of data.
+            prediction_column_name (str):
+                The name of the column to be predicted.
+            minority_class_label (int):
+                The minority class label.
+            classifier (str):
+                The ML algorithm to use when building the Binary Classification model.
+                Currently only ``XGBoost`` is supported. Defaults to ``XGBoost``.
+            fixed_recall_value (float):
+                A float in the range (0, 1) describing the value to fix for the recall when
+                building the Binary Classification model. Defaults to ``0.9``.
+
+        Returns:
+            float:
+                The score of the metric.
+        """
+        breakdown = cls.compute_breakdown(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+        )
+
+        return breakdown['score']
diff --git a/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py b/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py
new file mode 100644
index 00000000..5c683aa1
--- /dev/null
+++ b/sdmetrics/single_table/data_augmentation/binary_classifier_precision_efficacy.py
@@ -0,0 +1,10 @@
+"""Binary classifier precision efficacy metric."""
+
+from sdmetrics.single_table.data_augmentation.base import BaseDataAugmentationMetric
+
+
+class BinaryClassifierPrecisionEfficacy(BaseDataAugmentationMetric):
+    """Binary classifier precision efficacy metric."""
+
+    name = 'Binary Classifier Precision Efficacy'
+    metric_name = 'precision'
diff --git a/sdmetrics/single_table/data_augmentation/utils.py b/sdmetrics/single_table/data_augmentation/utils.py
new file mode 100644
index 00000000..c8f0b34a
--- /dev/null
+++ b/sdmetrics/single_table/data_augmentation/utils.py
@@ -0,0 +1,148 @@
+"""Utility methods for data augmentation metrics."""
+
+import pandas as pd
+
+
+def _validate_tables(real_training_data, synthetic_data, real_validation_data):
+    """Validate the tables of the Data Augmentation metrics."""
+    tables = [real_training_data, synthetic_data, real_validation_data]
+    if any(not isinstance(table, pd.DataFrame) for table in tables):
+        raise ValueError(
+            '`real_training_data`, `synthetic_data` and `real_validation_data` must be '
+            'pandas DataFrames.'
+        )
+
+
+def _validate_metadata(metadata):
+    """Validate the metadata of the Data Augmentation metrics."""
+    if not isinstance(metadata, dict):
+        raise TypeError(
+            f"Expected a dictionary but received a '{type(metadata).__name__}' instead."
+            " For SDV metadata objects, please use the 'to_dict' function to convert it"
+            ' to a dictionary.'
+        )
+
+
+def _validate_prediction_column_name(prediction_column_name):
+    """Validate the prediction column name of the Data Augmentation metrics."""
+    if not isinstance(prediction_column_name, str):
+        raise TypeError('`prediction_column_name` must be a string.')
+
+
+def _validate_classifier(classifier):
+    """Validate the classifier of the Data Augmentation metrics."""
+    if classifier is not None and not isinstance(classifier, str):
+        raise TypeError('`classifier` must be a string or None.')
+
+    if classifier is not None and classifier != 'XGBoost':
+        raise ValueError('Currently only `XGBoost` is supported as classifier.')
+
+
+def _validate_fixed_recall_value(fixed_recall_value):
+    """Validate the fixed recall value of the Data Augmentation metrics."""
+    if not isinstance(fixed_recall_value, (int, float)) or not (0 < fixed_recall_value < 1):
+        raise TypeError('`fixed_recall_value` must be a float in the range (0, 1).')
+
+
+def _validate_parameters(
+    real_training_data,
+    synthetic_data,
+    real_validation_data,
+    metadata,
+    prediction_column_name,
+    classifier,
+    fixed_recall_value,
+):
+    """Validate the parameters of the Data Augmentation metrics."""
+    _validate_tables(real_training_data, synthetic_data, real_validation_data)
+    _validate_metadata(metadata)
+    _validate_prediction_column_name(prediction_column_name)
+    _validate_classifier(classifier)
+    _validate_fixed_recall_value(fixed_recall_value)
+
+
+def _validate_data_and_metadata(
+    real_training_data,
+    synthetic_data,
+    real_validation_data,
+    metadata,
+    prediction_column_name,
+    minority_class_label,
+):
+    """Validate the data and metadata of the Data Augmentation metrics."""
+    if prediction_column_name not in metadata['columns']:
+        raise ValueError(
+            f'The column `{prediction_column_name}` is not described in the metadata.'
+            ' Please update your metadata.'
+        )
+
+    if metadata['columns'][prediction_column_name]['sdtype'] not in ('categorical', 'boolean'):
+        raise ValueError(
+            f'The column `{prediction_column_name}` must be either categorical or boolean.'
+            ' Please update your metadata.'
+        )
+
+    columns_match = (
+        set(real_training_data.columns)
+        == set(synthetic_data.columns)
+        == set(real_validation_data.columns)
+    )
+    data_metadata_mismatch = set(metadata['columns'].keys()) != set(real_training_data.columns)
+    if not columns_match or data_metadata_mismatch:
+        raise ValueError(
+            '`real_training_data`, `synthetic_data` and `real_validation_data` must have '
+            'the same columns and must match the columns described in the metadata.'
+        )
+
+    if minority_class_label not in real_training_data[prediction_column_name].unique():
+        raise ValueError(
+            f'The value `{minority_class_label}` is not present in the column '
+            f'`{prediction_column_name}` for the real training data.'
+        )
+
+    if minority_class_label not in real_validation_data[prediction_column_name].unique():
+        raise ValueError(
+            f"The metric can't be computed because the value `{minority_class_label}` "
+            f'is not present in the column `{prediction_column_name}` for the real validation data.'
+            ' The `precision` and `recall` are undefined for this case.'
+        )
+
+    synthetic_labels = set(synthetic_data[prediction_column_name].unique())
+    real_labels = set(real_training_data[prediction_column_name].unique())
+    if not synthetic_labels.issubset(real_labels):
+        to_print = "', '".join(sorted(synthetic_labels - real_labels))
+        raise ValueError(
+            f'The ``{prediction_column_name}`` column must have the same values in the real '
+            'and synthetic data. The following values are present in the synthetic data and'
+            f" not the real data: '{to_print}'"
+        )
+
+
+def _validate_inputs(
+    real_training_data,
+    synthetic_data,
+    real_validation_data,
+    metadata,
+    prediction_column_name,
+    minority_class_label,
+    classifier,
+    fixed_recall_value,
+):
+    """Validate the inputs of the Data Augmentation metrics."""
+    _validate_parameters(
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        classifier,
+        fixed_recall_value,
+    )
+    _validate_data_and_metadata(
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+    )
diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py
index 24427b5a..dac9f773 100644
--- a/sdmetrics/visualization.py
+++ b/sdmetrics/visualization.py
@@ -32,7 +32,7 @@ def wrapper(*args, **kwargs):
             ipython_interpreter = str(get_ipython())
             if 'ZMQInteractiveShell' in ipython_interpreter and 'iframe' in renderers:
                 # This means we are using jupyter notebook
-                pio.renderers.default = 'vscode'
+                pio.renderers.default = 'iframe'
         except Exception:
             pass
 
diff --git a/tasks.py b/tasks.py
index c7cd5a23..d42c8a61 100644
--- a/tasks.py
+++ b/tasks.py
@@ -30,8 +30,14 @@ def unit(c):
 
 
 @task
-def integration(c):
-    c.run('python -m pytest ./tests/integration --reruns 5 --disable-warnings --cov=sdmetrics --cov-report=xml:./integration_cov.xml')
+def integration(c, env=None):
+    env = env or {}
+    env.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
+
+    c.run(
+        'python -m pytest ./tests/integration --reruns 5 --disable-warnings --cov=sdmetrics --cov-report=xml:./integration_cov.xml',
+        env=env
+    )
 
 
 def _get_minimum_versions(dependencies, python_version):
@@ -88,8 +94,7 @@ def minimum(c):
     install_minimum(c)
    check_dependencies(c)
     unit(c)
-    integration(c)
-
+    integration(c, env={"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})
 
 @task
 def readme(c):
diff --git a/tests/integration/single_table/data_augmentation/__init__.py b/tests/integration/single_table/data_augmentation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py
new file mode 100644
index 00000000..102d3f22
--- /dev/null
+++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py
@@ -0,0 +1,249 @@
+import re
+
+import numpy as np
+import pytest
+
+from sdmetrics.demos import load_demo
+from sdmetrics.single_table.data_augmentation import BinaryClassifierPrecisionEfficacy
+
+
+class TestBinaryClassifierPrecisionEfficacy:
+    def test_end_to_end(self):
+        """Test the metric end-to-end."""
+        # Setup
+        np.random.seed(0)
+        real_data, synthetic_data, metadata = load_demo(modality='single_table')
+        mask_validation = np.random.rand(len(real_data)) < 0.8
+        real_training = real_data[mask_validation]
+        real_validation = real_data[~mask_validation]
+
+        # Run
+        score_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown(
+            real_training_data=real_training,
+            synthetic_data=synthetic_data,
+            real_validation_data=real_validation,
+            metadata=metadata,
+            prediction_column_name='gender',
+            minority_class_label='F',
+            classifier='XGBoost',
+            fixed_recall_value=0.8,
+        )
+
+        score = BinaryClassifierPrecisionEfficacy.compute(
+            real_training_data=real_training,
+            synthetic_data=synthetic_data,
+            real_validation_data=real_validation,
+            metadata=metadata,
+            prediction_column_name='gender',
+            minority_class_label='F',
+            classifier='XGBoost',
+            fixed_recall_value=0.8,
+        )
+
+        # Assert
+        expected_score_breakdown = {
+            'real_data_baseline': {
+                'recall_score_training': 0.8095238095238095,
+                'recall_score_validation': 0.07692307692307693,
+                'precision_score_validation': 1.0,
+                'prediction_counts_validation': {
+                    'true_positive': 1,
+                    'false_positive': 0,
+                    'true_negative': 25,
+                    'false_negative': 12,
+                },
+            },
+            'augmented_data': {
+                'recall_score_training': 0.8057553956834532,
+                'recall_score_validation': 0.0,
+                'precision_score_validation': 0.0,
+                'prediction_counts_validation': {
+                    'true_positive': 0,
+                    'false_positive': 2,
+                    'true_negative': 23,
+                    'false_negative': 13,
+                },
+            },
+            'parameters': {
+                'prediction_column_name': 'gender',
+                'minority_class_label': 'F',
+                'classifier': 'XGBoost',
+                'fixed_recall_value': 0.8,
+            },
+            'score': 0,
+        }
+        assert np.isclose(
+            score_breakdown['real_data_baseline']['recall_score_training'], 0.8, atol=0.1
+        )
+        assert np.isclose(
+            score_breakdown['augmented_data']['recall_score_validation'], 0.1, atol=0.1
+        )
+        assert score_breakdown == expected_score_breakdown
+        assert score == score_breakdown['score']
+
+    def test_with_no_minority_class_in_validation(self):
+        """Test the metric when the minority class is not present in the validation data."""
+        # Setup
+        np.random.seed(0)
+        real_data, synthetic_data, metadata = load_demo(modality='single_table')
+        mask_validation = np.random.rand(len(real_data)) < 0.8
+        real_training = real_data[mask_validation]
+        real_validation = real_data[~mask_validation]
+        real_validation['gender'] = 'M'
+        expected_error = re.escape(
+            "The metric can't be computed because the value `F` is not present in the column "
+            '`gender` for the real validation data. The `precision` and `recall` are undefined'
+            ' for this case.'
+        )
+
+        # Run and Assert
+        with pytest.raises(ValueError, match=expected_error):
+            BinaryClassifierPrecisionEfficacy.compute(
+                real_training_data=real_training,
+                synthetic_data=synthetic_data,
+                real_validation_data=real_validation,
+                metadata=metadata,
+                prediction_column_name='gender',
+                minority_class_label='F',
+                classifier='XGBoost',
+                fixed_recall_value=0.8,
+            )
+
+    def test_with_nan_target_column(self):
+        """Test the metric when the target column has NaN values."""
+        # Setup
+        np.random.seed(35)
+        real_data, synthetic_data, metadata = load_demo(modality='single_table')
+        mask_validation = np.random.rand(len(real_data)) < 0.8
+        real_training = real_data[mask_validation].reset_index(drop=True)
+        real_validation = real_data[~mask_validation].reset_index(drop=True)
+        real_training.loc[:3, 'gender'] = np.nan
+        real_validation.loc[:5, 'gender'] = np.nan
+
+        # Run
+        result_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown(
+            real_training_data=real_training,
+            synthetic_data=synthetic_data,
+            real_validation_data=real_validation,
+            metadata=metadata,
+            prediction_column_name='gender',
+            minority_class_label='F',
+            classifier='XGBoost',
+            fixed_recall_value=0.8,
+        )
+
+        # Assert
+        expected_result = {
+            'real_data_baseline': {
+                'recall_score_training': 0.8135593220338984,
+                'recall_score_validation': 0.23076923076923078,
+                'precision_score_validation': 0.42857142857142855,
+                'prediction_counts_validation': {
+                    'true_positive': 3,
+                    'false_positive': 4,
+                    'true_negative': 29,
+                    'false_negative': 10,
+                },
+            },
+            'augmented_data': {
+                'recall_score_training': 0.8,
+                'recall_score_validation': 0.15384615384615385,
+                'precision_score_validation': 0.4,
+                'prediction_counts_validation': {
+                    'true_positive': 2,
+                    'false_positive': 3,
+                    'true_negative': 30,
+                    'false_negative': 11,
+                },
+            },
+            'parameters': {
+                'prediction_column_name': 'gender',
+                'minority_class_label': 'F',
+                'classifier': 'XGBoost',
+                'fixed_recall_value': 0.8,
+            },
+            'score': 0,
+        }
+        assert result_breakdown == expected_result
+
+    def test_with_minority_being_majority(self):
+        """Test the metric when the minority class is the majority class."""
+        # Setup
+        np.random.seed(0)
+        real_data, synthetic_data, metadata = load_demo(modality='single_table')
+        mask_validation = np.random.rand(len(real_data)) < 0.8
+        real_training = real_data[mask_validation]
+        real_validation = real_data[~mask_validation]
+
+        # Run
+        score = BinaryClassifierPrecisionEfficacy.compute(
+            real_training_data=real_training,
+            synthetic_data=synthetic_data,
+            real_validation_data=real_validation,
+            metadata=metadata,
+            prediction_column_name='gender',
+            minority_class_label='F',
+            classifier='XGBoost',
+            fixed_recall_value=0.8,
+        )
+
+        # Assert
+        assert score == 0
+
+    def test_with_multi_class(self):
+        """Test the metric with multi-class classification.
+
+        The `high_spec` column has 3 values `Commerce`, `Science`, and `Arts`.
+ """ + # Setup + np.random.seed(0) + real_data, synthetic_data, metadata = load_demo(modality='single_table') + mask_validation = np.random.rand(len(real_data)) < 0.8 + real_training = real_data[mask_validation] + real_validation = real_data[~mask_validation] + + # Run + score_breakdown = BinaryClassifierPrecisionEfficacy.compute_breakdown( + real_training_data=real_training, + synthetic_data=synthetic_data, + real_validation_data=real_validation, + metadata=metadata, + prediction_column_name='high_spec', + minority_class_label='Science', + classifier='XGBoost', + fixed_recall_value=0.8, + ) + + # Assert + expected_score_breakdown = { + 'real_data_baseline': { + 'recall_score_training': 0.8076923076923077, + 'recall_score_validation': 0.6923076923076923, + 'precision_score_validation': 0.9, + 'prediction_counts_validation': { + 'true_positive': 9, + 'false_positive': 1, + 'true_negative': 24, + 'false_negative': 4, + }, + }, + 'augmented_data': { + 'recall_score_training': 0.8035714285714286, + 'recall_score_validation': 0.6153846153846154, + 'precision_score_validation': 0.8888888888888888, + 'prediction_counts_validation': { + 'true_positive': 8, + 'false_positive': 1, + 'true_negative': 24, + 'false_negative': 5, + }, + }, + 'parameters': { + 'prediction_column_name': 'high_spec', + 'minority_class_label': 'Science', + 'classifier': 'XGBoost', + 'fixed_recall_value': 0.8, + }, + 'score': 0, + } + assert score_breakdown == expected_score_breakdown diff --git a/tests/unit/single_table/data_augmentation/__init__.py b/tests/unit/single_table/data_augmentation/__init__.py new file mode 100644 index 00000000..d3856339 --- /dev/null +++ b/tests/unit/single_table/data_augmentation/__init__.py @@ -0,0 +1 @@ +"""SDMetrics unit testing for the single_table data_augmentation module.""" diff --git a/tests/unit/single_table/data_augmentation/test_base.py b/tests/unit/single_table/data_augmentation/test_base.py new file mode 100644 index 00000000..0ff25b9d --- /dev/null +++ b/tests/unit/single_table/data_augmentation/test_base.py @@ -0,0 +1,425 @@ +"""Test for the base BaseDataAugmentationMetric metrics.""" + +from unittest.mock import Mock, patch + +import numpy as np +import pandas as pd +import pytest +from sklearn.metrics import precision_score +from xgboost import XGBClassifier + +from sdmetrics.single_table.data_augmentation.base import ( + BaseDataAugmentationMetric, + ClassifierTrainer, +) + + +@pytest.fixture +def real_training_data(): + return pd.DataFrame({ + 'target': [1, 0, 0], + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'b'], + 'boolean': [True, False, True], + 'datetime': pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03']), + }) + + +@pytest.fixture +def synthetic_data(): + return pd.DataFrame({ + 'target': [0, 1, 0], + 'numerical': [2, 2, 3], + 'categorical': ['b', 'a', 'b'], + 'boolean': [False, True, False], + 'datetime': pd.to_datetime(['2021-01-25', '2021-01-02', '2021-01-03']), + }) + + +@pytest.fixture +def real_validation_data(): + return pd.DataFrame({ + 'target': [1, 0, 0], + 'numerical': [3, 3, 3], + 'categorical': ['a', 'b', 'b'], + 'boolean': [True, False, True], + 'datetime': pd.to_datetime(['2021-01-01', '2021-01-12', '2021-01-03']), + }) + + +@pytest.fixture +def metadata(): + return { + 'columns': { + 'target': {'sdtype': 'categorical'}, + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'boolean': {'sdtype': 'boolean'}, + 'datetime': {'sdtype': 'datetime'}, + } + } + + +class TestClassifierTrainer: + def 
+    def test__init__(self):
+        """Test the ``__init__`` method."""
+        # Run
+        trainer = ClassifierTrainer('target', 1, 'XGBoost', 0.69, 'recall')
+
+        # Assert
+        assert trainer.prediction_column_name == 'target'
+        assert trainer.minority_class_label == 1
+        assert trainer._classifier_name == 'XGBoost'
+        assert trainer.fixed_value == 0.69
+        assert trainer.metric_name == 'recall'
+        assert trainer._metric_to_fix == 'precision'
+        assert trainer._metric_method == precision_score
+        assert isinstance(trainer._classifier, XGBClassifier)
+
+    @patch('sdmetrics.single_table.data_augmentation.base.precision_recall_curve')
+    def test_get_best_threshold(self, mock_precision_recall_curve, real_training_data):
+        """Test the ``get_best_threshold`` method."""
+        # Setup
+        trainer = ClassifierTrainer('target', 1, 'XGBoost', 0.69, 'recall')
+        trainer._classifier = Mock()
+        trainer._classifier.predict_proba = Mock(
+            return_value=np.array([[0.1, 0.9], [0.9, 0.1], [0.9, 0.1]])
+        )
+        mock_precision_recall_curve.return_value = [
+            np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.72, 0.8, 0.9, 1.0]),
+            np.array([0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.0]),
+            np.array([0.02, 0.15, 0.25, 0.35, 0.42, 0.51, 0.63, 0.77, 0.82, 0.93, 0.97]),
+        ]
+        train_data = real_training_data[['numerical']]
+        train_target = real_training_data['target']
+
+        # Run
+        best_threshold = trainer.get_best_threshold(train_data, train_target)
+
+        # Assert
+        assert best_threshold == 0.63
+
+    def test_train_model(self, real_training_data):
+        """Test the ``train_model`` method.
+
+        Here the true target values are [1, 0, 0] and the predicted ones based on the
+        best threshold are [1, 0, 1]. So the precision score should be 0.5.
+        """
+        # Setup
+        trainer = ClassifierTrainer('target', 1, 'XGBoost', 0.69, 'recall')
+        trainer.get_best_threshold = Mock(return_value=0.63)
+        trainer._classifier = Mock()
+        trainer._classifier.predict_proba = Mock(
+            return_value=np.array([[0.3, 0.7], [0.4, 0.6], [0.3, 0.7]])
+        )
+        trainer._metric_method = precision_score
+        real_training_data_copy = real_training_data.copy()
+
+        # Run
+        score = trainer.train_model(real_training_data_copy)
+
+        # Assert
+        assert score == 0.5
+        assert trainer._best_threshold == 0.63
+
+    def test_compute_validation_scores(self, real_validation_data):
+        """Test the ``compute_validation_scores`` method."""
+        # Setup
+        trainer = ClassifierTrainer('target', 1, 'XGBoost', 0.69, 'recall')
+        trainer._best_threshold = 0.63
+        trainer._classifier = Mock()
+        trainer._classifier.predict_proba = Mock(
+            return_value=np.array([[0.3, 0.7], [0.4, 0.6], [0.3, 0.7]])
+        )
+
+        # Run
+        recall, precision, prediction_counts_validation = trainer.compute_validation_scores(
+            real_validation_data
+        )
+
+        # Assert
+        assert recall == 1.0
+        assert precision == 0.5
+        assert prediction_counts_validation == {
+            'true_positive': 1,
+            'false_positive': 1,
+            'true_negative': 1,
+            'false_negative': 0,
+        }
+
+    def test_get_scores(self, real_training_data, real_validation_data):
+        """Test the ``get_scores`` method."""
+        # Setup
+        trainer = ClassifierTrainer('target', 1, 'XGBoost', 0.69, 'precision')
+        trainer.train_model = Mock(return_value=0.78)
+        trainer.compute_validation_scores = Mock(
+            return_value=(
+                1.0,
+                0.5,
+                {
+                    'true_positive': 1,
+                    'false_positive': 1,
+                    'true_negative': 1,
+                    'false_negative': 0,
+                },
+            )
+        )
+
+        # Run
+        scores = trainer.get_scores(real_training_data, real_validation_data)
+
+        # Assert
+        assert scores == {
+            'recall_score_training': 0.78,
+            'recall_score_validation': 1.0,
+            'precision_score_validation': 0.5,
+            'prediction_counts_validation': {
+                'true_positive': 1,
+                'false_positive': 1,
+                'true_negative': 1,
+                'false_negative': 0,
+            },
+        }
+
+
+class TestBaseDataAugmentationMetric:
+    """Test the BaseDataAugmentationMetric class."""
+
+    def test__fit(self, real_training_data, metadata):
+        """Test the ``_fit`` method."""
+        # Setup
+        metric = BaseDataAugmentationMetric()
+
+        # Run
+        discrete_columns, datetime_columns = metric._fit(real_training_data, metadata, 'target')
+
+        # Assert
+        assert discrete_columns == ['categorical', 'boolean']
+        assert datetime_columns == ['datetime']
+
+    def test__transform(self, real_training_data, synthetic_data, real_validation_data):
+        """Test the ``_transform`` method."""
+        # Setup
+        metric = BaseDataAugmentationMetric()
+        discrete_columns = ['categorical', 'boolean']
+        datetime_columns = ['datetime']
+        tables = {
+            'real_training_data': real_training_data,
+            'synthetic_data': synthetic_data,
+            'real_validation_data': real_validation_data,
+        }
+
+        # Run
+        transformed = metric._transform(tables, discrete_columns, datetime_columns, 'target', 1)
+
+        # Assert
+        expected_transformed = {
+            'real_training_data': pd.DataFrame({
+                'target': [1, 0, 0],
+                'numerical': [1, 2, 3],
+                'categorical': pd.Categorical(['a', 'b', 'b']),
+                'boolean': pd.Categorical([True, False, True]),
+                'datetime': pd.to_numeric(
+                    pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03'])
+                ),
+            }),
+            'synthetic_data': pd.DataFrame({
+                'target': [0, 1, 0],
+                'numerical': [2, 2, 3],
+                'categorical': pd.Categorical(['b', 'a', 'b']),
+                'boolean': pd.Categorical([False, True, False]),
+                'datetime': pd.to_numeric(
+                    pd.to_datetime(['2021-01-25', '2021-01-02', '2021-01-03'])
+                ),
+            }),
+            'real_validation_data': pd.DataFrame({
+                'target': [1, 0, 0],
+                'numerical': [3, 3, 3],
+                'categorical': pd.Categorical(['a', 'b', 'b']),
+                'boolean': pd.Categorical([True, False, True]),
+                'datetime': pd.to_numeric(
+                    pd.to_datetime(['2021-01-01', '2021-01-12', '2021-01-03'])
+                ),
+            }),
+        }
+        for table_name, table in transformed.items():
+            pd.testing.assert_frame_equal(
+                table, expected_transformed[table_name], check_dtype=False
+            )
+
+    def test__fit_transform(
+        self, real_training_data, synthetic_data, real_validation_data, metadata
+    ):
+        """Test the ``_fit_transform`` method."""
+        # Setup
+        metric = BaseDataAugmentationMetric()
+        BaseDataAugmentationMetric._fit = Mock()
+        discrete_columns = ['categorical', 'boolean']
+        datetime_columns = ['datetime']
+        BaseDataAugmentationMetric._fit.return_value = (
+            discrete_columns,
+            datetime_columns,
+        )
+        tables = {
+            'real_training_data': real_training_data,
+            'synthetic_data': synthetic_data,
+            'real_validation_data': real_validation_data,
+        }
+        BaseDataAugmentationMetric._transform = Mock(return_value=tables)
+
+        # Run
+        transformed = metric._fit_transform(
+            real_training_data, synthetic_data, real_validation_data, metadata, 'target', 1
+        )
+
+        # Assert
+        BaseDataAugmentationMetric._fit.assert_called_once_with(
+            real_training_data, metadata, 'target'
+        )
+        BaseDataAugmentationMetric._transform.assert_called_once_with(
+            tables, discrete_columns, datetime_columns, 'target', 1
+        )
+        for table_name, table in transformed.items():
+            assert table.equals(tables[table_name])
+
+    @patch('sdmetrics.single_table.data_augmentation.base._validate_inputs')
+    @patch(
+        'sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric._fit_transform'
+    )
+    @patch(
+        'sdmetrics.single_table.data_augmentation.base.ClassifierTrainer',
+    )
+    @patch.object(BaseDataAugmentationMetric, 'metric_name', 'precision')
+    def test_compute_breakdown(
+        self,
+        mock_classifier_trainer,
+        mock_fit_transform,
+        mock_validate_inputs,
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+    ):
+        """Test the ``compute_breakdown`` method."""
+        # Setup
+        prediction_column_name = 'target'
+        minority_class_label = 1
+        classifier = 'XGBoost'
+        fixed_recall_value = 0.9
+
+        real_data_baseline = {
+            'precision_score_training': 0.43,
+            'recall_score_validation': 0.7,
+            'precision_score_validation': 0.5,
+            'prediction_counts_validation': {
+                'true_positive': 1,
+                'false_positive': 1,
+                'true_negative': 1,
+                'false_negative': 0,
+            },
+        }
+        augmented_table_result = {
+            'precision_score_training': 0.78,
+            'recall_score_validation': 0.9,
+            'precision_score_validation': 0.7,
+            'prediction_counts_validation': {
+                'true_positive': 2,
+                'false_positive': 2,
+                'true_negative': 1,
+                'false_negative': 0,
+            },
+        }
+        mock_fit_transform.return_value = {
+            'real_training_data': real_training_data,
+            'synthetic_data': synthetic_data,
+            'real_validation_data': real_validation_data,
+        }
+        mock_classifier_trainer.return_value.get_scores.side_effect = [
+            real_data_baseline,
+            augmented_table_result,
+        ]
+        mock_classifier_trainer.return_value.prediction_column_name = prediction_column_name
+        mock_classifier_trainer.return_value.minority_class_label = minority_class_label
+        mock_classifier_trainer.return_value._classifier_name = classifier
+        mock_classifier_trainer.return_value.fixed_value = fixed_recall_value
+
+        # Run
+        score_breakdown = BaseDataAugmentationMetric.compute_breakdown(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+        )
+
+        # Assert
+        expected_result = {
+            'score': 0.19999999999999996,
+            'real_data_baseline': real_data_baseline,
+            'augmented_data': augmented_table_result,
+            'parameters': {
+                'prediction_column_name': prediction_column_name,
+                'minority_class_label': minority_class_label,
+                'classifier': classifier,
+                'fixed_recall_value': fixed_recall_value,
+            },
+        }
+        assert score_breakdown == expected_result
+        mock_validate_inputs.assert_called_once_with(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+        )
+        mock_fit_transform.assert_called_once_with(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+        )
+
+    @patch(
+        'sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric.compute_breakdown'
+    )
+    def test_compute(
+        self,
+        mock_compute_breakdown,
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+    ):
+        """Test the ``compute`` method."""
+        # Setup
+        prediction_column_name = 'target'
+        minority_class_label = 1
+        classifier = 'XGBoost'
+        fixed_recall_value = 0.9
+        mock_compute_breakdown.return_value = {
+            'score': 0.9,
+            'other_key': 'other_value',
+        }
+
+        # Run
+        score = BaseDataAugmentationMetric.compute(
+            real_training_data,
+            synthetic_data,
+            real_validation_data,
+            metadata,
+            prediction_column_name,
+            minority_class_label,
+            classifier,
+            fixed_recall_value,
+        )
+
+        # Assert
+        assert score == 0.9
diff --git a/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py
new file mode 100644
index 00000000..36052a4c
--- /dev/null
+++ b/tests/unit/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py
@@ -0,0 +1,17 @@
+"""Tests for the Binary Classifier Precision Efficacy metric."""
+
+from sdmetrics.single_table.data_augmentation.binary_classifier_precision_efficacy import (
+    BinaryClassifierPrecisionEfficacy,
+)
+
+
+class TestBinaryClassifierPrecisionEfficacy:
+    def test_class_attributes(self):
+        """Test the class attributes."""
+        # Setup
+        expected_name = 'Binary Classifier Precision Efficacy'
+        expected_metric_name = 'precision'
+
+        # Run and Assert
+        assert BinaryClassifierPrecisionEfficacy.name == expected_name
+        assert BinaryClassifierPrecisionEfficacy.metric_name == expected_metric_name
diff --git a/tests/unit/single_table/data_augmentation/test_utils.py b/tests/unit/single_table/data_augmentation/test_utils.py
new file mode 100644
index 00000000..c7435ae0
--- /dev/null
+++ b/tests/unit/single_table/data_augmentation/test_utils.py
@@ -0,0 +1,203 @@
+import re
+from copy import deepcopy
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+
+from sdmetrics.single_table.data_augmentation.utils import (
+    _validate_data_and_metadata,
+    _validate_inputs,
+    _validate_parameters,
+)
+
+
+def test__validate_parameters():
+    """Test the ``_validate_parameters`` function."""
+    # Setup
+    expected_message_dataframes = re.escape(
+        '`real_training_data`, `synthetic_data` and `real_validation_data` must be'
+        ' pandas DataFrames.'
+    )
+    expected_message_metadata = re.escape(
+        "Expected a dictionary but received a 'list' instead."
+        " For SDV metadata objects, please use the 'to_dict' function to convert it"
+        ' to a dictionary.'
+    )
+    expected_message_prediction_column_name = re.escape(
+        '`prediction_column_name` must be a string.'
+    )
+    expected_message_classifier = re.escape('`classifier` must be a string or None.')
+    expected_message_classifier_value = re.escape(
+        'Currently only `XGBoost` is supported as classifier.'
+    )
+    expected_message_fixed_recall_value = re.escape(
+        '`fixed_recall_value` must be a float in the range (0, 1).'
+    )
+    inputs = {
+        'real_training_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'synthetic_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'real_validation_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'metadata': {'columns': {'target': {'sdtype': 'categorical'}}},
+        'prediction_column_name': 'target',
+        'classifier': 'XGBoost',
+        'fixed_recall_value': 0.9,
+    }
+
+    # Run and Assert
+    _validate_parameters(**inputs)
+    wrong_inputs_dataframes = deepcopy(inputs)
+    wrong_inputs_dataframes['real_training_data'] = 'wrong'
+    with pytest.raises(ValueError, match=expected_message_dataframes):
+        _validate_parameters(**wrong_inputs_dataframes)
+
+    wrong_inputs_metadata = deepcopy(inputs)
+    wrong_inputs_metadata['metadata'] = []
+    with pytest.raises(TypeError, match=expected_message_metadata):
+        _validate_parameters(**wrong_inputs_metadata)
+
+    wrong_inputs_prediction_column_name = deepcopy(inputs)
+    wrong_inputs_prediction_column_name['prediction_column_name'] = 1
+    with pytest.raises(TypeError, match=expected_message_prediction_column_name):
+        _validate_parameters(**wrong_inputs_prediction_column_name)
+
+    wrong_inputs_classifier_type = deepcopy(inputs)
+    wrong_inputs_classifier_type['classifier'] = 1
+    with pytest.raises(TypeError, match=expected_message_classifier):
+        _validate_parameters(**wrong_inputs_classifier_type)
+
+    wrong_inputs_classifier = deepcopy(inputs)
+    wrong_inputs_classifier['classifier'] = 'LogisticRegression'
+    with pytest.raises(ValueError, match=expected_message_classifier_value):
+        _validate_parameters(**wrong_inputs_classifier)
+
+    wrong_inputs_fixed_recall_value_type = deepcopy(inputs)
+    wrong_inputs_fixed_recall_value_type['fixed_recall_value'] = '0.9'
+    with pytest.raises(TypeError, match=expected_message_fixed_recall_value):
+        _validate_parameters(**wrong_inputs_fixed_recall_value_type)
+
+    wrong_inputs_fixed_recall_value = deepcopy(inputs)
+    wrong_inputs_fixed_recall_value['fixed_recall_value'] = 1.2
+    with pytest.raises(TypeError, match=expected_message_fixed_recall_value):
+        _validate_parameters(**wrong_inputs_fixed_recall_value)
+
+
+def test__validate_data_and_metadata():
+    """Test the ``_validate_data_and_metadata`` function."""
+    # Setup
+    inputs = {
+        'real_training_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'synthetic_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'real_validation_data': pd.DataFrame({'target': [1, 0, 0]}),
+        'metadata': {'columns': {'target': {'sdtype': 'categorical'}}},
+        'prediction_column_name': 'target',
+        'minority_class_label': 1,
+    }
+    expected_message_missing_prediction_column = re.escape(
+        'The column `target` is not described in the metadata. Please update your metadata.'
+    )
+    expected_message_sdtype = re.escape(
+        'The column `target` must be either categorical or boolean. Please update your metadata.'
+    )
+    expected_message_column_mismatch = re.escape(
+        '`real_training_data`, `synthetic_data` and `real_validation_data` must have the '
+        'same columns and must match the columns described in the metadata.'
+    )
+    expected_message_value = re.escape(
+        'The value `1` is not present in the column `target` for the real training data.'
+    )
+    expected_error_missing_minority = re.escape(
+        "The metric can't be computed because the value `1` is not present in "
+        'the column `target` for the real validation data. The `precision` and `recall`'
+        ' are undefined for this case.'
+    )
+    expected_error_synthetic_wrong_label = re.escape(
+        'The ``target`` column must have the same values in the real and synthetic data. '
+        'The following values are present in the synthetic data and not the real'
+        " data: 'wrong_1', 'wrong_2'"
+    )
+
+    # Run and Assert
+    _validate_data_and_metadata(**inputs)
+    missing_prediction_column = deepcopy(inputs)
+    missing_prediction_column['metadata']['columns'].pop('target')
+    with pytest.raises(ValueError, match=expected_message_missing_prediction_column):
+        _validate_data_and_metadata(**missing_prediction_column)
+
+    wrong_inputs_sdtype = deepcopy(inputs)
+    wrong_inputs_sdtype['metadata']['columns']['target']['sdtype'] = 'numerical'
+    with pytest.raises(ValueError, match=expected_message_sdtype):
+        _validate_data_and_metadata(**wrong_inputs_sdtype)
+
+    wrong_column_metadata = deepcopy(inputs)
+    wrong_column_metadata['metadata']['columns'].update({'new_column': {'sdtype': 'categorical'}})
+    with pytest.raises(ValueError, match=expected_message_column_mismatch):
+        _validate_data_and_metadata(**wrong_column_metadata)
+
+    wrong_column_data = deepcopy(inputs)
+    wrong_column_data['real_training_data'] = pd.DataFrame({'new_column': [1, 0, 0]})
+    with pytest.raises(ValueError, match=expected_message_column_mismatch):
+        _validate_data_and_metadata(**wrong_column_data)
+
+    missing_minority_class_label = deepcopy(inputs)
+    missing_minority_class_label['real_training_data'] = pd.DataFrame({'target': [0, 0, 0]})
+    with pytest.raises(ValueError, match=expected_message_value):
+        _validate_data_and_metadata(**missing_minority_class_label)
+
+    missing_minority_class_label_validation = deepcopy(inputs)
+    missing_minority_class_label_validation['real_validation_data'] = pd.DataFrame({
+        'target': [0, 0, 0]
+    })
+    with pytest.raises(ValueError, match=expected_error_missing_minority):
+        _validate_data_and_metadata(**missing_minority_class_label_validation)
+
+    wrong_synthetic_label = deepcopy(inputs)
+    wrong_synthetic_label['synthetic_data'] = pd.DataFrame({'target': [0, 1, 'wrong_1', 'wrong_2']})
+    with pytest.raises(ValueError, match=expected_error_synthetic_wrong_label):
+        _validate_data_and_metadata(**wrong_synthetic_label)
+
+
+@patch('sdmetrics.single_table.data_augmentation.utils._validate_parameters')
+@patch('sdmetrics.single_table.data_augmentation.utils._validate_data_and_metadata')
+def test__validate_inputs_mock(mock_validate_data_and_metadata, mock_validate_parameters):
+    """Test the ``_validate_inputs`` function."""
+    # Setup
+    real_training_data = pd.DataFrame({'target': [1, 0, 0]})
+    synthetic_data = pd.DataFrame({'target': [1, 0, 0]})
+    real_validation_data = pd.DataFrame({'target': [1, 0, 0]})
+    metadata = {'columns': {'target': {'sdtype': 'categorical'}}}
+    prediction_column_name = 'target'
+    minority_class_label = 1
+    classifier = 'XGBoost'
+    fixed_recall_value = 0.9
+
+    # Run
+    _validate_inputs(
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+        classifier,
+        fixed_recall_value,
+    )
+
+    # Assert
+    mock_validate_data_and_metadata.assert_called_once_with(
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        minority_class_label,
+    )
+    mock_validate_parameters.assert_called_once_with(
+        real_training_data,
+        synthetic_data,
+        real_validation_data,
+        metadata,
+        prediction_column_name,
+        classifier,
+        fixed_recall_value,
+    )
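
Note for reviewers: a minimal usage sketch of the metric added in this PR, mirroring the integration tests above. The demo table, the 'gender'/'F' target, the 80/20 split, and the seed are illustrative only; everything else (function names, parameters) comes from this diff.

    import numpy as np

    from sdmetrics.demos import load_demo
    from sdmetrics.single_table.data_augmentation import BinaryClassifierPrecisionEfficacy

    # Load the single-table demo and hold out ~20% of the real rows for validation.
    real_data, synthetic_data, metadata = load_demo(modality='single_table')
    np.random.seed(0)
    train_mask = np.random.rand(len(real_data)) < 0.8

    # The score is max(0, augmented precision - baseline precision) on the
    # validation set, with the classification threshold tuned so that training
    # recall stays at or above fixed_recall_value.
    score = BinaryClassifierPrecisionEfficacy.compute(
        real_training_data=real_data[train_mask],
        synthetic_data=synthetic_data,
        real_validation_data=real_data[~train_mask],
        metadata=metadata,
        prediction_column_name='gender',
        minority_class_label='F',
        classifier='XGBoost',
        fixed_recall_value=0.9,
    )
    print(score)  # 0.0 whenever augmentation does not improve validation precision

For finer-grained inspection, `compute_breakdown` returns the same score alongside the per-table recall/precision and confusion-matrix counts shown in the integration tests.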