diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py new file mode 100644 index 00000000..f857a1b3 --- /dev/null +++ b/sdmetrics/_utils_metadata.py @@ -0,0 +1,155 @@ +import warnings + +import pandas as pd + +MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean') + + +def _validate_metadata_dict(metadata): + """Validate the metadata type.""" + if not isinstance(metadata, dict): + raise TypeError( + f"Expected a dictionary but received a '{type(metadata).__name__}' instead." + " For SDV metadata objects, please use the 'to_dict' function to convert it" + ' to a dictionary.' + ) + + +def _validate_single_table_metadata(metadata): + """Validate the metadata for a single table.""" + _validate_metadata_dict(metadata) + if 'columns' not in metadata: + raise ValueError( + "Single-table metadata must include a 'columns' key that maps column names" + ' to their corresponding information.' + ) + + +def _validate_multi_table_metadata(metadata): + """Validate the metadata for multiple tables.""" + _validate_metadata_dict(metadata) + if 'tables' not in metadata: + raise ValueError( + "Multi-table metadata must include a 'tables' key that maps table names" + ' to their respective metadata.' + ) + for table_name, table_metadata in metadata['tables'].items(): + try: + _validate_single_table_metadata(table_metadata) + except ValueError as e: + raise ValueError(f"Error in table '{table_name}': {str(e)}") + + +def _validate_metadata(metadata): + """Validate the metadata.""" + _validate_metadata_dict(metadata) + if ('columns' not in metadata) and ('tables' not in metadata): + raise ValueError( + "Metadata must include either a 'columns' key for single-table metadata" + " or a 'tables' key for multi-table metadata." + ) + + if 'tables' in metadata: + _validate_multi_table_metadata(metadata) + + +def handle_single_and_multi_table(single_table_func): + """Decorator to handle both single and multi table functions.""" + + def wrapper(data, metadata): + if isinstance(data, pd.DataFrame): + return single_table_func(data, metadata) + + result = {} + for table_name in data: + result[table_name] = single_table_func(data[table_name], metadata['tables'][table_name]) + + return result + + return wrapper + + +@handle_single_and_multi_table +def _convert_datetime_columns(data, metadata): + """Convert datetime columns to datetime type.""" + columns_missing_datetime_format = [] + for column in metadata['columns']: + if metadata['columns'][column]['sdtype'] == 'datetime': + is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) + if not is_datetime: + datetime_format = metadata['columns'][column].get('format') + try: + if datetime_format: + data[column] = pd.to_datetime(data[column], format=datetime_format) + else: + columns_missing_datetime_format.append(column) + data[column] = pd.to_datetime(data[column]) + except Exception as e: + raise ValueError( + f"Failed to convert column '{column}' to datetime with the error: {str(e)}" + ) from e + + if columns_missing_datetime_format: + columns_to_print = "', '".join(columns_missing_datetime_format) + warnings.warn( + f'No `datetime_format` provided in the metadata when trying to convert the columns' + f" '{columns_to_print}' to datetime. The format will be inferred, but it may not" + ' be accurate.', + UserWarning, + ) + + return data + + +@handle_single_and_multi_table +def _remove_missing_columns_metadata(data, metadata): + """Remove columns that are not present in the metadata.""" + columns_in_metadata = set(metadata['columns'].keys()) + columns_in_data = set(data.columns) + columns_to_remove = columns_in_data - columns_in_metadata + extra_metadata_columns = columns_in_metadata - columns_in_data + if columns_to_remove: + columns_to_print = "', '".join(sorted(columns_to_remove)) + warnings.warn( + f"The columns ('{columns_to_print}') are not present in the metadata." + 'They will not be included for further evaluation.', + UserWarning, + ) + elif extra_metadata_columns: + columns_to_print = "', '".join(sorted(extra_metadata_columns)) + warnings.warn( + f"The columns ('{columns_to_print}') are in the metadata but they are not " + 'present in the data.', + UserWarning, + ) + + data = data.drop(columns=columns_to_remove) + column_intersection = [column for column in data.columns if column in metadata['columns']] + + return data[column_intersection] + + +@handle_single_and_multi_table +def _remove_non_modelable_columns(data, metadata): + """Remove columns that are not modelable. + + All modelable columns are numerical, datetime, categorical, or boolean sdtypes. + """ + columns_modelable = [] + for column in metadata['columns']: + column_sdtype = metadata['columns'][column]['sdtype'] + if column_sdtype in MODELABLE_SDTYPES and column in data.columns: + columns_modelable.append(column) + + return data[columns_modelable] + + +def _process_data_with_metadata(data, metadata, keep_modelable_columns_only=False): + """Process the data according to the metadata.""" + _validate_metadata_dict(metadata) + data = _convert_datetime_columns(data, metadata) + data = _remove_missing_columns_metadata(data, metadata) + if keep_modelable_columns_only: + data = _remove_non_modelable_columns(data, metadata) + + return data diff --git a/sdmetrics/multi_table/detection/parent_child.py b/sdmetrics/multi_table/detection/parent_child.py index b2f4a776..c423229c 100644 --- a/sdmetrics/multi_table/detection/parent_child.py +++ b/sdmetrics/multi_table/detection/parent_child.py @@ -2,6 +2,7 @@ import numpy as np +from sdmetrics._utils_metadata import _validate_multi_table_metadata from sdmetrics.multi_table.detection.base import DetectionMetric from sdmetrics.single_table.detection import LogisticDetection, SVCDetection from sdmetrics.utils import get_columns_from_metadata, nested_attrs_meta @@ -37,9 +38,7 @@ class ParentChildDetectionMetric( @staticmethod def _extract_foreign_keys(metadata): - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_multi_table_metadata(metadata) foreign_keys = [] for child_table, child_meta in metadata['tables'].items(): for child_key, field_meta in get_columns_from_metadata(child_meta).items(): diff --git a/sdmetrics/multi_table/multi_single_table.py b/sdmetrics/multi_table/multi_single_table.py index 032d7c42..2e34fc28 100644 --- a/sdmetrics/multi_table/multi_single_table.py +++ b/sdmetrics/multi_table/multi_single_table.py @@ -6,6 +6,7 @@ import numpy as np from sdmetrics import single_table +from sdmetrics._utils_metadata import _validate_multi_table_metadata from sdmetrics.errors import IncomputableMetricError from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import nested_attrs_meta @@ -77,9 +78,8 @@ def _compute(self, real_data, synthetic_data, metadata=None, **kwargs): if metadata is None: metadata = {'tables': defaultdict(type(None))} - elif not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_multi_table_metadata(metadata) scores = {} errors = {} for table_name, real_table in real_data.items(): diff --git a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py index 54c16cad..c4410248 100644 --- a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py @@ -3,6 +3,9 @@ import numpy as np from scipy.stats import ks_2samp +from sdmetrics._utils_metadata import ( + _validate_metadata_dict, +) from sdmetrics.goal import Goal from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution @@ -53,9 +56,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata): """ if set(real_data.keys()) != set(synthetic_data.keys()): raise ValueError('`real_data` and `synthetic_data` must have the same tables.') - if not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_metadata_dict(metadata) score_breakdowns = {} for rel in metadata.get('relationships', []): cardinality_real = get_cardinality_distribution( diff --git a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py index 7accfbde..a40126f2 100644 --- a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py @@ -4,6 +4,7 @@ import numpy as np +from sdmetrics._utils_metadata import _validate_metadata_dict from sdmetrics.goal import Goal from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution @@ -107,9 +108,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata=None, statistic=' raise ValueError('`real_data` and `synthetic_data` must have the same tables.') if metadata is None: raise ValueError('`metadata` cannot be ``None``.') - if not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_metadata_dict(metadata) score_breakdowns = {} for rel in metadata.get('relationships', []): cardinality_real = get_cardinality_distribution( diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 614b8083..8fe6002d 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -13,6 +13,7 @@ import pandas as pd import tqdm +from sdmetrics._utils_metadata import _validate_metadata from sdmetrics.reports.utils import convert_datetime_columns from sdmetrics.visualization import set_plotly_config @@ -68,17 +69,6 @@ def _validate_data_format(self, real_data, synthetic_data): ) raise ValueError(error_message) - def _validate_metadata_format(self, metadata): - """Validate the metadata.""" - if not isinstance(metadata, dict): - raise TypeError('The provided metadata is not a dictionary.') - - if 'columns' not in metadata: - raise ValueError( - 'Single table reports expect metadata to contain a "columns" key with a mapping' - ' from column names to column informations.' - ) - def _validate(self, real_data, synthetic_data, metadata): """Validate the inputs. @@ -91,7 +81,7 @@ def _validate(self, real_data, synthetic_data, metadata): The metadata of the table. """ self._validate_data_format(real_data, synthetic_data) - self._validate_metadata_format(metadata) + _validate_metadata(metadata) self._validate_metadata_matches_data(real_data, synthetic_data, metadata) @staticmethod @@ -142,13 +132,6 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): verbose (bool): Whether or not to print report summary and progress. """ - if not isinstance(metadata, dict): - raise TypeError( - f"Expected a dictionary but received a '{type(metadata).__name__}' instead." - " For SDV metadata objects, please use the 'to_dict' function to convert it" - ' to a dictionary.' - ) - self._validate(real_data, synthetic_data, metadata) self.convert_datetimes(real_data, synthetic_data, metadata) diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index 1b78aa07..319c182f 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -40,22 +40,6 @@ def _validate_data_format(self, real_data, synthetic_data): raise ValueError(error_message) - def _validate_metadata_format(self, metadata): - """Validate the metadata.""" - if not isinstance(metadata, dict): - raise TypeError('The provided metadata is not a dictionary.') - - if 'tables' not in metadata: - raise ValueError( - 'Multi table reports expect metadata to contain a "tables" key with a mapping' - ' from table names to metadata for each table.' - ) - for table_name, table_metadata in metadata['tables'].items(): - if 'columns' not in table_metadata: - raise ValueError( - f'The metadata for table "{table_name}" is missing a "columns" key.' - ) - def _validate_relationships(self, real_data, synthetic_data, metadata): """Validate that the relationships are valid.""" for rel in metadata.get('relationships', []): diff --git a/sdmetrics/single_table/base.py b/sdmetrics/single_table/base.py index 32877b06..25257adc 100644 --- a/sdmetrics/single_table/base.py +++ b/sdmetrics/single_table/base.py @@ -5,6 +5,7 @@ import pandas as pd +from sdmetrics._utils_metadata import _validate_single_table_metadata from sdmetrics.base import BaseMetric from sdmetrics.errors import IncomputableMetricError from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta @@ -92,7 +93,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None): The real data. synthetic_data(pandas.DataFrame): The synthetic data. - metadata (dict or Metadata or None): + metadata (dict): The metadata, if any. Returns: @@ -108,9 +109,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None): raise ValueError('`real_data` and `synthetic_data` must have the same columns') if metadata is not None: - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_single_table_metadata(metadata) fields = get_columns_from_metadata(metadata) for column in real_data.columns: if column not in fields: diff --git a/sdmetrics/single_table/data_augmentation/base.py b/sdmetrics/single_table/data_augmentation/base.py index 4016b48c..784a1112 100644 --- a/sdmetrics/single_table/data_augmentation/base.py +++ b/sdmetrics/single_table/data_augmentation/base.py @@ -9,7 +9,10 @@ from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric -from sdmetrics.single_table.data_augmentation.utils import _validate_inputs +from sdmetrics.single_table.data_augmentation.utils import ( + _process_data_with_metadata_ml_efficacy_metrics, + _validate_inputs, +) METRIC_NAME_TO_METHOD = {'recall': recall_score, 'precision': precision_score} @@ -104,7 +107,11 @@ def _fit(cls, data, metadata, prediction_column_name): """Fit preprocessing parameters.""" discrete_columns = [] datetime_columns = [] - for column, column_meta in metadata['columns'].items(): + data_columns = data.columns + metadata_columns = metadata['columns'].keys() + common_columns = set(data_columns).intersection(metadata_columns) + for column in sorted(common_columns): + column_meta = metadata['columns'][column] if (column_meta['sdtype'] in ['categorical', 'boolean']) and ( column != prediction_column_name ): @@ -192,6 +199,11 @@ def compute_breakdown( classifier, fixed_value, ) + (real_training_data, synthetic_data, real_validation_data) = ( + _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata + ) + ) preprocessed_tables = cls._fit_transform( real_training_data, synthetic_data, diff --git a/sdmetrics/single_table/data_augmentation/utils.py b/sdmetrics/single_table/data_augmentation/utils.py index 70156e7d..e2a6c172 100644 --- a/sdmetrics/single_table/data_augmentation/utils.py +++ b/sdmetrics/single_table/data_augmentation/utils.py @@ -2,6 +2,8 @@ import pandas as pd +from sdmetrics._utils_metadata import _process_data_with_metadata, _validate_single_table_metadata + def _validate_tables(real_training_data, synthetic_data, real_validation_data): """Validate the tables of the Data Augmentation metrics.""" @@ -13,16 +15,6 @@ def _validate_tables(real_training_data, synthetic_data, real_validation_data): ) -def _validate_metadata(metadata): - """Validate the metadata of the Data Augmentation metrics.""" - if not isinstance(metadata, dict): - raise TypeError( - f"Expected a dictionary but received a '{type(metadata).__name__}' instead." - " For SDV metadata objects, please use the 'to_dict' function to convert it" - ' to a dictionary.' - ) - - def _validate_prediction_column_name(prediction_column_name): """Validate the prediction column name of the Data Augmentation metrics.""" if not isinstance(prediction_column_name, str): @@ -55,7 +47,7 @@ def _validate_parameters( ): """Validate the parameters of the Data Augmentation metrics.""" _validate_tables(real_training_data, synthetic_data, real_validation_data) - _validate_metadata(metadata) + _validate_single_table_metadata(metadata) _validate_prediction_column_name(prediction_column_name) _validate_classifier(classifier) _validate_fixed_recall_value(fixed_recall_value) @@ -82,18 +74,6 @@ def _validate_data_and_metadata( ' Please update your metadata.' ) - columns_match = ( - set(real_training_data.columns) - == set(synthetic_data.columns) - == set(real_validation_data.columns) - ) - data_metadata_mismatch = set(metadata['columns'].keys()) != set(real_training_data.columns) - if not columns_match or data_metadata_mismatch: - raise ValueError( - '`real_training_data`, `synthetic_data` and `real_validation_data` must have ' - 'the same columns and must match the columns described in the metadata.' - ) - if minority_class_label not in real_training_data[prediction_column_name].unique(): raise ValueError( f'The value `{minority_class_label}` is not present in the column ' @@ -146,3 +126,14 @@ def _validate_inputs( prediction_column_name, minority_class_label, ) + + +def _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata +): + """Process the data for ML efficacy metrics according to the metadata.""" + real_training_data = _process_data_with_metadata(real_training_data, metadata, True) + synthetic_data = _process_data_with_metadata(synthetic_data, metadata, True) + real_validation_data = _process_data_with_metadata(real_validation_data, metadata, True) + + return real_training_data, synthetic_data, real_validation_data diff --git a/sdmetrics/timeseries/base.py b/sdmetrics/timeseries/base.py index 95c4fbce..6c157458 100644 --- a/sdmetrics/timeseries/base.py +++ b/sdmetrics/timeseries/base.py @@ -4,6 +4,7 @@ import pandas as pd +from sdmetrics._utils_metadata import _validate_metadata_dict from sdmetrics.base import BaseMetric from sdmetrics.utils import get_columns_from_metadata @@ -51,9 +52,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key raise ValueError('`real_data` and `synthetic_data` must have the same columns') if metadata is not None: - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_metadata_dict(metadata) fields = get_columns_from_metadata(metadata) for column in real_data.columns: if column not in fields: diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py index aceedd1b..7459ee4a 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py @@ -44,13 +44,13 @@ def test_end_to_end(self): expected_score_breakdown = { 'real_data_baseline': { 'recall_score_training': 0.8095238095238095, - 'recall_score_validation': 0.07692307692307693, - 'precision_score_validation': 1.0, + 'recall_score_validation': 0.15384615384615385, + 'precision_score_validation': 0.4, 'prediction_counts_validation': { - 'true_positive': 1, - 'false_positive': 0, - 'true_negative': 25, - 'false_negative': 12, + 'true_positive': 2, + 'false_positive': 3, + 'true_negative': 22, + 'false_negative': 11, }, }, 'augmented_data': { @@ -59,8 +59,8 @@ def test_end_to_end(self): 'precision_score_validation': 0.0, 'prediction_counts_validation': { 'true_positive': 0, - 'false_positive': 2, - 'true_negative': 23, + 'false_positive': 1, + 'true_negative': 24, 'false_negative': 13, }, }, @@ -70,7 +70,7 @@ def test_end_to_end(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0, + 'score': 0.3, } assert np.isclose( score_breakdown['real_data_baseline']['recall_score_training'], 0.8, atol=0.1 @@ -147,13 +147,13 @@ def test_with_nan_target_column(self): }, 'augmented_data': { 'recall_score_training': 0.8, - 'recall_score_validation': 0.15384615384615385, - 'precision_score_validation': 0.4, + 'recall_score_validation': 0.23076923076923078, + 'precision_score_validation': 0.6, 'prediction_counts_validation': { - 'true_positive': 2, - 'false_positive': 3, - 'true_negative': 30, - 'false_negative': 11, + 'true_positive': 3, + 'false_positive': 2, + 'true_negative': 31, + 'false_negative': 10, }, }, 'parameters': { @@ -162,7 +162,7 @@ def test_with_nan_target_column(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0.48571428571428577, + 'score': 0.5857142857142857, } assert result_breakdown == expected_result @@ -188,7 +188,7 @@ def test_with_minority_being_majority(self): ) # Assert - assert score == 0 + assert score == 0.3 def test_with_multi_class(self): """Test the metric with multi-class classification. @@ -229,13 +229,13 @@ def test_with_multi_class(self): }, 'augmented_data': { 'recall_score_training': 0.8035714285714286, - 'recall_score_validation': 0.6153846153846154, - 'precision_score_validation': 0.8888888888888888, + 'recall_score_validation': 0.46153846153846156, + 'precision_score_validation': 1.0, 'prediction_counts_validation': { - 'true_positive': 8, - 'false_positive': 1, - 'true_negative': 24, - 'false_negative': 5, + 'true_positive': 6, + 'false_positive': 0, + 'true_negative': 25, + 'false_negative': 7, }, }, 'parameters': { @@ -244,6 +244,6 @@ def test_with_multi_class(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0.4944444444444444, + 'score': 0.55, } assert score_breakdown == expected_score_breakdown diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index fd9e6843..66c40dee 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -124,7 +124,7 @@ def test_with_nan_target_column(self): ) # Assert - assert result_breakdown['score'] in (0.5, 0.5384615384615385) + assert result_breakdown['score'] in (0.3846153846153846, 0.6538461538461539) def test_with_minority_being_majority(self): """Test the metric when the minority class is the majority class.""" @@ -148,7 +148,7 @@ def test_with_minority_being_majority(self): ) # Assert - assert score == 0.46153846153846156 + assert score in (0.5, 0.3846153846153846) def test_with_multi_class(self): """Test the metric with multi-class classification. @@ -175,4 +175,43 @@ def test_with_multi_class(self): ) # Assert - assert score_breakdown['score'] in (0.46153846153846156, 0.5384615384615384) + assert score_breakdown['score'] in (0.4230769230769231, 0.5384615384615384) + + def test_speical_data_metadata_config(self): + """Test the metric with a special data and metadata configuration. + + in this test: + - The `start_date` column is an object datetime column. + - The synthetic data has an extra column compared to the metadata (`extra_column`). + - The metadata has an extra column compared to the data (`extra_metadata_column`). + """ + # Setup + np.random.seed(0) + real_data, synthetic_data, metadata = load_demo(modality='single_table') + metadata['columns']['extra_metadata_column'] = {'sdtype': 'categorical'} + synthetic_data['extra_column'] = 'extra' + real_data['start_date'] = real_data['start_date'].astype(str) + synthetic_data['start_date'] = synthetic_data['start_date'].astype(str) + mask_validation = np.random.rand(len(real_data)) < 0.8 + real_training = real_data[mask_validation] + real_validation = real_data[~mask_validation] + warning_datetime = re.escape( + 'No `datetime_format` provided in the metadata when trying to convert the columns' + " 'start_date' to datetime. The format will be inferred, but it may not be accurate." + ) + + # Run + with pytest.warns(UserWarning, match=warning_datetime): + score = BinaryClassifierRecallEfficacy.compute( + real_training_data=real_training, + synthetic_data=synthetic_data, + real_validation_data=real_validation, + metadata=metadata, + prediction_column_name='gender', + minority_class_label='F', + classifier='XGBoost', + fixed_precision_value=0.8, + ) + + # Assert + assert score in (0.5, 0.3846153846153846) diff --git a/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py b/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py index 74804500..f1ece8dc 100644 --- a/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py +++ b/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py @@ -158,14 +158,14 @@ def test_compute_breakdown(self): # Setup metadata = { 'tables': { - 'tableA': {'fields': {'col1': {}}}, + 'tableA': {'columns': {'col1': {}}}, 'tableB': { - 'fields': { + 'columns': { 'col1': {}, 'col2': {}, }, }, - 'tableC': {'fields': {'col2': {}}}, + 'tableC': {'columns': {'col2': {}}}, }, 'relationships': [ { @@ -230,9 +230,9 @@ def test_compute_breakdown_no_relationships(self): # Setup metadata = { 'tables': { - 'tableA': {'fields': {'col1': {}}}, - 'tableB': {'fields': {'col1': {}, 'col2': {}}}, - 'tableC': {'fields': {'col2': {}}}, + 'tableA': {'columns': {'col1': {}}}, + 'tableB': {'columns': {'col1': {}, 'col2': {}}}, + 'tableC': {'columns': {'col2': {}}}, }, } real_data = { diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index f5b50c9b..b2f308e0 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -41,53 +41,6 @@ def test__validate_data_format(self): with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) - def test__validate_metadata_format(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata is not a dictionnary. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = [] - - # Run and Assert - expected_message = 'The provided metadata is not a dictionary.' - with pytest.raises(TypeError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_with_no_tables(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'tables' key. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = {} - - # Run and Assert - expected_message = ( - 'Multi table reports expect metadata to contain a "tables" key with a mapping from ' - 'table names to metadata for each table.' - ) - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_with_no_columns(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'columns' key. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = {'tables': {'Table_1': {}}} - - # Run and Assert - expected_message = 'The metadata for table "Table_1" is missing a "columns" key.' - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - def test__validate_relationships(self): """Test the ``_validate_relationships`` method.""" # Setup diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 3a82d7f9..4e72f54e 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -31,38 +31,6 @@ def test__validate_data_format(self): with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) - def test__validate_metadata_format(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata is not a dictionary. - """ - # Setup - base_report = BaseReport() - metadata = 'metadata' - - # Run and Assert - expected_message = 'The provided metadata is not a dictionary.' - with pytest.raises(TypeError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_no_columns(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'columns' key. - """ - # Setup - base_report = BaseReport() - metadata = {} - - # Run and Assert - expected_message = ( - 'Single table reports expect metadata to contain a "columns" key with a mapping' - ' from column names to column informations.' - ) - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - def test__validate_metadata_matches_data(self): """Test the ``_validate_metadata_matches_data`` method. @@ -137,7 +105,8 @@ def test__validate_metadata_matches_data_no_mismatch(self): # Run and Assert base_report._validate_metadata_matches_data(real_data, synthetic_data, metadata) - def test__validate(self): + @patch('sdmetrics.reports.base_report._validate_metadata') + def test__validate(self, mock__validate_metadata): """Test the ``_validate`` method.""" # Setup base_report = BaseReport() @@ -165,6 +134,7 @@ def test__validate(self): base_report._validate(real_data, synthetic_data, metadata) # Assert + mock__validate_metadata.assert_called_once_with(metadata) mock__validate_metadata_matches_data.assert_called_once_with( real_data, synthetic_data, metadata ) diff --git a/tests/unit/single_table/data_augmentation/test_base.py b/tests/unit/single_table/data_augmentation/test_base.py index 0e919408..2ac43757 100644 --- a/tests/unit/single_table/data_augmentation/test_base.py +++ b/tests/unit/single_table/data_augmentation/test_base.py @@ -194,7 +194,7 @@ def test__fit(self, real_training_data, metadata): discrete_columns, datetime_columns = metric._fit(real_training_data, metadata, 'target') # Assert - assert discrete_columns == ['categorical', 'boolean'] + assert discrete_columns == ['boolean', 'categorical'] assert datetime_columns == ['datetime'] def test__transform(self, real_training_data, synthetic_data, real_validation_data): @@ -282,6 +282,9 @@ def test__fit_transform( for table_name, table in transformed.items(): assert table.equals(tables[table_name]) + @patch( + 'sdmetrics.single_table.data_augmentation.base._process_data_with_metadata_ml_efficacy_metrics' + ) @patch('sdmetrics.single_table.data_augmentation.base._validate_inputs') @patch( 'sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric._fit_transform' @@ -295,6 +298,7 @@ def test_compute_breakdown( mock_classifier_trainer, mock_fit_transfrom, mock_validate_inputs, + mock_process_data_with_metadata, real_training_data, synthetic_data, real_validation_data, @@ -306,7 +310,7 @@ def test_compute_breakdown( minority_class_label = 1 classifier = 'XGBoost' fixed_recall_value = 0.9 - + mock_process_data_with_metadata.side_effect = lambda x, y, z, t: (x, y, z) real_data_baseline = { 'precision_score_training': 0.43, 'recall_score_validation': 0.7, @@ -378,6 +382,9 @@ def test_compute_breakdown( classifier, fixed_recall_value, ) + mock_process_data_with_metadata.assert_called_once_with( + real_training_data, synthetic_data, real_validation_data, metadata + ) mock_fit_transfrom.assert_called_once_with( real_training_data, synthetic_data, diff --git a/tests/unit/single_table/data_augmentation/test_utils.py b/tests/unit/single_table/data_augmentation/test_utils.py index 48b05757..3b018c93 100644 --- a/tests/unit/single_table/data_augmentation/test_utils.py +++ b/tests/unit/single_table/data_augmentation/test_utils.py @@ -1,11 +1,12 @@ import re from copy import deepcopy -from unittest.mock import patch +from unittest.mock import call, patch import pandas as pd import pytest from sdmetrics.single_table.data_augmentation.utils import ( + _process_data_with_metadata_ml_efficacy_metrics, _validate_data_and_metadata, _validate_inputs, _validate_parameters, @@ -99,10 +100,6 @@ def test__validate_data_and_metadata(): expected_message_sdtype = re.escape( 'The column `target` must be either categorical or boolean. Please update your metadata.' ) - expected_message_column_missmatch = re.escape( - '`real_training_data`, `synthetic_data` and `real_validation_data` must have the ' - 'same columns and must match the columns described in the metadata.' - ) expected_message_value = re.escape( 'The value `1` is not present in the column `target` for the real training data.' ) @@ -129,16 +126,6 @@ def test__validate_data_and_metadata(): with pytest.raises(ValueError, match=expected_message_sdtype): _validate_data_and_metadata(**wrong_inputs_sdtype) - wrong_column_metadata = deepcopy(inputs) - wrong_column_metadata['metadata']['columns'].update({'new_column': {'sdtype': 'categorical'}}) - with pytest.raises(ValueError, match=expected_message_column_missmatch): - _validate_data_and_metadata(**wrong_column_metadata) - - wrong_column_data = deepcopy(inputs) - wrong_column_data['real_training_data'] = pd.DataFrame({'new_column': [1, 0, 0]}) - with pytest.raises(ValueError, match=expected_message_column_missmatch): - _validate_data_and_metadata(**wrong_column_data) - missing_minority_class_label = deepcopy(inputs) missing_minority_class_label['real_training_data'] = pd.DataFrame({'target': [0, 0, 0]}) with pytest.raises(ValueError, match=expected_message_value): @@ -201,3 +188,43 @@ def test__validate_inputs_mock(mock_validate_data_and_metadata, mock_validate_pa classifier, fixed_recall_value, ) + + +@patch('sdmetrics.single_table.data_augmentation.utils._process_data_with_metadata') +def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_metadata): + """Test the ``_process_data_with_metadata_ml_efficacy_metrics`` method.""" + # Setup + mock_process_data_with_metadata.side_effect = lambda data, metadata, x: data + real_training_data = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + }) + synthetic_data = pd.DataFrame({ + 'numerical': [4, 5, 6], + 'categorical': ['a', 'b', 'c'], + }) + real_validation_data = pd.DataFrame({ + 'numerical': [7, 8, 9], + 'categorical': ['a', 'b', 'c'], + }) + metadata = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + } + + # Run + result = _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata + ) + + # Assert + pd.testing.assert_frame_equal(result[0], real_training_data) + pd.testing.assert_frame_equal(result[1], synthetic_data) + pd.testing.assert_frame_equal(result[2], real_validation_data) + mock_process_data_with_metadata.assert_has_calls([ + call(real_training_data, metadata, True), + call(synthetic_data, metadata, True), + call(real_validation_data, metadata, True), + ]) diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 66daf369..fc15bc58 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -24,7 +24,7 @@ def test_compute(self, bad_pomegranate): metric = BNLikelihood() # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' + expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) @@ -36,6 +36,8 @@ def test_compute(self, bad_pomegranate): metric = BNLogLikelihood() # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' + expected_message = expected_message = ( + r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 + ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py new file mode 100644 index 00000000..f94e6218 --- /dev/null +++ b/tests/unit/test__utils_metadata.py @@ -0,0 +1,324 @@ +import re +from copy import deepcopy +from unittest.mock import patch + +import pandas as pd +import pytest + +from sdmetrics._utils_metadata import ( + _convert_datetime_columns, + _process_data_with_metadata, + _remove_missing_columns_metadata, + _remove_non_modelable_columns, + _validate_metadata, + _validate_metadata_dict, + _validate_multi_table_metadata, + _validate_single_table_metadata, +) + + +@pytest.fixture +def data(): + return { + 'table1': pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_str': ['2021-01-01', '2021-01-02', '2021-01-03'], + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + }), + 'table2': pd.DataFrame({ + 'datetime_missing_format': ['2024-01-01', '2023-01-02', '2024-01-03'], + 'extra_column_1': [1, 2, 3], + 'extra_column_2': ['a', 'b', 'c'], + }), + } + + +@pytest.fixture +def metadata(): + return { + 'tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime_str': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'datetime': {'sdtype': 'datetime'}, + } + }, + 'table2': { + 'columns': { + 'datetime_missing_format': {'sdtype': 'datetime'}, + } + }, + } + } + + +def test__validate_metadata_dict(metadata): + """Test the ``_validate_metadata_dict`` method.""" + # Setup + metadata_wrong = 'wrong' + expected_error = re.escape( + f"Expected a dictionary but received a '{type(metadata_wrong).__name__}' instead." + " For SDV metadata objects, please use the 'to_dict' function to convert it" + ' to a dictionary.' + ) + + # Run and Assert + _validate_metadata_dict(metadata) + with pytest.raises(TypeError, match=expected_error): + _validate_metadata_dict(metadata_wrong) + + +def test__validate_single_table_metadata(metadata): + """Test the ``_validate_single_table_metadata`` method.""" + # Setup + metadata_wrong = { + 'wrong_key': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + } + expected_error = re.escape( + "Single-table metadata must include a 'columns' key that maps column names" + ' to their corresponding information.' + ) + + # Run and Assert + _validate_single_table_metadata(metadata['tables']['table1']) + with pytest.raises(ValueError, match=expected_error): + _validate_single_table_metadata(metadata_wrong) + + +def test__validate_multi_table_metadata(metadata): + """Test the ``_validate_multi_table_metadata`` method.""" + # Setup + metadata_wrong = { + 'wrong_tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + } + } + + metadata_wrong_single_table = { + 'tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + 'table2': { + 'wrong_key': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + } + } + expected_error = re.escape( + "Multi-table metadata must include a 'tables' key that maps table names to" + ' their respective metadata.' + ) + expected_error_single_table = re.escape( + "Error in table 'table2': Single-table metadata must include a 'columns' key" + ' that maps column names to their corresponding information.' + ) + + # Run and Assert + _validate_multi_table_metadata(metadata) + with pytest.raises(ValueError, match=expected_error): + _validate_multi_table_metadata(metadata_wrong) + + with pytest.raises(ValueError, match=expected_error_single_table): + _validate_multi_table_metadata(metadata_wrong_single_table) + + +@patch('sdmetrics._utils_metadata._validate_multi_table_metadata') +def test__validate_metadata(mock_validate_multi_table_metadata, metadata): + """Test the ``_validate_metadata`` method.""" + # Setup + wrong_metadata = {'worng_key': 'wrong_value'} + expected_error = re.escape( + "Metadata must include either a 'columns' key for single-table metadata" + " or a 'tables' key for multi-table metadata." + ) + # Run + _validate_metadata(metadata) + with pytest.raises(ValueError, match=expected_error): + _validate_metadata(wrong_metadata) + + # Assert + mock_validate_multi_table_metadata.assert_called_once_with(metadata) + + +def test__convert_datetime_columns(data, metadata): + """Test the ``_convert_datetime_columns`` method.""" + # Setup + expected_df_single_table = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_str': pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03']), + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + }) + expected_result_multi_table_table = { + 'table1': expected_df_single_table, + 'table2': data['table2'], + } + + # Run + result_multi_table = _convert_datetime_columns(data, metadata) + result_single_table = _convert_datetime_columns(data['table1'], metadata['tables']['table1']) + + # Assert + for table_name, table in result_multi_table.items(): + pd.testing.assert_frame_equal(table, expected_result_multi_table_table[table_name]) + + pd.testing.assert_frame_equal(result_single_table, expected_df_single_table) + + +def test_convert_datetime_columns_with_failures(): + """Test the ``_convert_datetime_columns`` when pandas can't convert to datetime.""" + # Setup + wrong_data = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_1': ['2021-01-01', '20-error', '2021-01-03'], + 'datetime_2': ['2025-01-01', '2025-01-24', '2025-13-04'], + }) + metadata = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime_1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'datetime_2': {'sdtype': 'datetime'}, + } + } + error_message = r"^Failed to convert column 'datetime_1' to datetime with the error:" + + # Run and Assert + with pytest.raises(ValueError, match=error_message): + _convert_datetime_columns(wrong_data, metadata) + + +def test__remove_missing_columns_metadata(data, metadata): + """Test the ``_remove_missing_columns_metadata`` method.""" + # Setup + expected_warning_missing_column_metadata = re.escape( + "The columns ('extra_column_1', 'extra_column_2') are not present in the metadata." + 'They will not be included for further evaluation.' + ) + expected_warning_extra_metadata_column = re.escape( + "The columns ('numerical') are in the metadata but they are not present in the data." + ) + data['table1'] = data['table1'].drop(columns=['numerical']) + + # Run + with pytest.warns(UserWarning, match=expected_warning_extra_metadata_column): + _remove_missing_columns_metadata(data['table1'], metadata['tables']['table1']) + + with pytest.warns(UserWarning, match=expected_warning_missing_column_metadata): + result = _remove_missing_columns_metadata(data['table2'], metadata['tables']['table2']) + + # Assert + pd.testing.assert_frame_equal( + result, data['table2'].drop(columns=['extra_column_1', 'extra_column_2']) + ) + + +def test__remove_missing_columns_metadata_with_single_table(data, metadata): + """Test the ``_remove_missing_columns_metadata`` method with a single table.""" + # Setup + expected_df_single_table = data['table2'].drop(columns=['extra_column_1', 'extra_column_2']) + expected_df_multi_table = { + 'table1': data['table1'], + 'table2': expected_df_single_table, + } + + # Run + result_single_table = _remove_missing_columns_metadata( + data['table2'], metadata['tables']['table2'] + ) + result_multi_table = _remove_missing_columns_metadata(data, metadata) + + # Assert + pd.testing.assert_frame_equal(result_single_table, expected_df_single_table) + for table_name, table in result_multi_table.items(): + pd.testing.assert_frame_equal(table, expected_df_multi_table[table_name]) + + +def test__remove_non_modelable_columns(data, metadata): + """Test the ``_remove_non_modelable_columns`` method.""" + # Setup + single_table_df = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'id': [1, 2, 3], + 'boolean': [True, False, True], + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + 'ssn': ['123-45-6789', '987-65-4321', '123-45-6789'], + }) + metadata_single_table = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'id': {'sdtype': 'id'}, + 'boolean': {'sdtype': 'boolean'}, + 'datetime': {'sdtype': 'datetime'}, + 'ssn': {'sdtype': 'ssn'}, + } + } + multi_table = { + 'table1': deepcopy(single_table_df), + 'table2': data['table2'], + } + multi_table_metadata = { + 'tables': { + 'table1': metadata_single_table, + 'table2': metadata['tables']['table2'], + } + } + + # Run + result_single_table = _remove_non_modelable_columns(single_table_df, metadata_single_table) + result_multi_table = _remove_non_modelable_columns(multi_table, multi_table_metadata) + + # Assert + pd.testing.assert_frame_equal(result_single_table, single_table_df.drop(columns=['id', 'ssn'])) + pd.testing.assert_frame_equal( + result_multi_table['table1'], single_table_df.drop(columns=['id', 'ssn']) + ) + + +@patch('sdmetrics._utils_metadata._validate_metadata_dict') +@patch('sdmetrics._utils_metadata._remove_missing_columns_metadata') +@patch('sdmetrics._utils_metadata._convert_datetime_columns') +@patch('sdmetrics._utils_metadata._remove_non_modelable_columns') +def test__process_data_with_metadata( + mock_remove_non_modelable_columns, + mock_convert_datetime_columns, + mock_remove_missing_columns_metadata, + mock_validate_metadata_dict, + data, + metadata, +): + """Test the ``_process_data_with_metadata``method.""" + # Setup + mock_convert_datetime_columns.side_effect = lambda data, metadata: data + mock_remove_missing_columns_metadata.side_effect = lambda data, metadata: data + + # Run and Assert + _process_data_with_metadata(data, metadata) + + mock_convert_datetime_columns.assert_called_once_with(data, metadata) + mock_remove_missing_columns_metadata.assert_called_once_with(data, metadata) + mock_validate_metadata_dict.assert_called_once_with(metadata) + mock_remove_non_modelable_columns.assert_not_called() + + _process_data_with_metadata(data, metadata, keep_modelable_columns_only=True) + mock_remove_non_modelable_columns.assert_called_once_with(data, metadata)