From ec159d7ead6fbebd5bba8ed60669716161a1905c Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 20 Feb 2025 18:18:24 +0000 Subject: [PATCH 01/14] define utils method --- sdmetrics/_utils_metadata.py | 123 +++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 sdmetrics/_utils_metadata.py diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py new file mode 100644 index 00000000..4beabc67 --- /dev/null +++ b/sdmetrics/_utils_metadata.py @@ -0,0 +1,123 @@ +import warnings + +import pandas as pd + +MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean') + + +def _validate_metadata_dict(metadata): + """Validate the metadata of the Data Augmentation metrics.""" + if not isinstance(metadata, dict): + raise TypeError( + f"Expected a dictionary but received a '{type(metadata).__name__}' instead." + " For SDV metadata objects, please use the 'to_dict' function to convert it" + ' to a dictionary.' + ) + + +def handle_single_and_multi_table(single_table_func): + """Decorator to handle both single and multi table functions.""" + + def wrapper(data, metadata): + if isinstance(data, pd.DataFrame): + return single_table_func(data, metadata) + + result = {} + for table_name in data: + result[table_name] = single_table_func(data[table_name], metadata['tables'][table_name]) + + return result + + return wrapper + + +@handle_single_and_multi_table +def _convert_datetime_columns(data, metadata): + """Convert datetime columns to datetime type.""" + columns_missing_datetime_format = [] + for column in metadata['columns']: + if metadata['columns'][column]['sdtype'] == 'datetime': + is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) + if not is_datetime: + datetime_format = metadata['columns'][column].get('format') + if datetime_format: + data[column] = pd.to_datetime(data[column], format=datetime_format) + else: + columns_missing_datetime_format.append(column) + data[column] = pd.to_datetime(data[column]) + + if columns_missing_datetime_format: + columns_to_print = "', '".join(columns_missing_datetime_format) + warnings.warn( + f'No `datetime_format` provided in the metadata when trying to convert the columns' + f" '{columns_to_print}' to datetime. The format will be inferred, but it may not" + ' be accurate.', + UserWarning, + ) + + return data + + +@handle_single_and_multi_table +def _remove_missing_columns_metadata(data, metadata): + """Remove columns that are not present in the metadata.""" + columns_in_metadata = set(metadata['columns'].keys()) + columns_in_data = set(data.columns) + columns_to_remove = columns_in_data - columns_in_metadata + extra_metadata_columns = columns_in_metadata - columns_in_data + if columns_to_remove: + columns_to_print = "', '".join(sorted(columns_to_remove)) + warnings.warn( + f"Some columns ('{columns_to_print}') are not present in the metadata." + 'They will not be included for further evaluation.', + UserWarning, + ) + elif extra_metadata_columns: + columns_to_print = "', '".join(sorted(extra_metadata_columns)) + warnings.warn( + f"Some columns ('{columns_to_print}') are in the metadata but they are not " + 'present in the data.', + UserWarning, + ) + + data = data.drop(columns=columns_to_remove) + column_intersection = [column for column in data.columns if column in metadata['columns']] + + return data[column_intersection] + + +@handle_single_and_multi_table +def _remove_non_modelable_columns(data, metadata): + """Remove columns that are not modelable. + + All modelable columns are numerical, datetime, categorical, or boolean sdtypes. + """ + columns_modelable = [] + for column in metadata['columns']: + column_sdtype = metadata['columns'][column]['sdtype'] + if column_sdtype in MODELABLE_SDTYPES and column in data.columns: + columns_modelable.append(column) + + return data[columns_modelable] + + +def _process_data_with_metadata(data, metadata, keep_modelable_columns_only=False): + """Process the data according to the metadata.""" + _validate_metadata_dict(metadata) + data = _convert_datetime_columns(data, metadata) + data = _remove_missing_columns_metadata(data, metadata) + if keep_modelable_columns_only: + data = _remove_non_modelable_columns(data, metadata) + + return data + + +def _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata +): + """Process the data for ML efficacy metrics according to the metadata.""" + real_training_data = _process_data_with_metadata(real_training_data, metadata, True) + synthetic_data = _process_data_with_metadata(synthetic_data, metadata, True) + real_validation_data = _process_data_with_metadata(real_validation_data, metadata, True) + + return real_training_data, synthetic_data, real_validation_data From f7afb5a0a087f9168317d4a68befea1dc8c93631 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 20 Feb 2025 18:18:34 +0000 Subject: [PATCH 02/14] def --- sdmetrics/single_table/data_augmentation/base.py | 12 +++++++++++- sdmetrics/single_table/data_augmentation/utils.py | 12 ------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/sdmetrics/single_table/data_augmentation/base.py b/sdmetrics/single_table/data_augmentation/base.py index 4016b48c..ebbb7b47 100644 --- a/sdmetrics/single_table/data_augmentation/base.py +++ b/sdmetrics/single_table/data_augmentation/base.py @@ -7,6 +7,7 @@ from sklearn.metrics import confusion_matrix, precision_recall_curve, precision_score, recall_score from xgboost import XGBClassifier +from sdmetrics._utils_metadata import _process_data_with_metadata_ml_efficacy_metrics from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric from sdmetrics.single_table.data_augmentation.utils import _validate_inputs @@ -104,7 +105,11 @@ def _fit(cls, data, metadata, prediction_column_name): """Fit preprocessing parameters.""" discrete_columns = [] datetime_columns = [] - for column, column_meta in metadata['columns'].items(): + data_columns = data.columns + metadata_columns = metadata['columns'].keys() + common_columns = set(data_columns).intersection(metadata_columns) + for column in sorted(common_columns): + column_meta = metadata['columns'][column] if (column_meta['sdtype'] in ['categorical', 'boolean']) and ( column != prediction_column_name ): @@ -192,6 +197,11 @@ def compute_breakdown( classifier, fixed_value, ) + (real_training_data, synthetic_data, real_validation_data) = ( + _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata + ) + ) preprocessed_tables = cls._fit_transform( real_training_data, synthetic_data, diff --git a/sdmetrics/single_table/data_augmentation/utils.py b/sdmetrics/single_table/data_augmentation/utils.py index 70156e7d..e015403e 100644 --- a/sdmetrics/single_table/data_augmentation/utils.py +++ b/sdmetrics/single_table/data_augmentation/utils.py @@ -82,18 +82,6 @@ def _validate_data_and_metadata( ' Please update your metadata.' ) - columns_match = ( - set(real_training_data.columns) - == set(synthetic_data.columns) - == set(real_validation_data.columns) - ) - data_metadata_mismatch = set(metadata['columns'].keys()) != set(real_training_data.columns) - if not columns_match or data_metadata_mismatch: - raise ValueError( - '`real_training_data`, `synthetic_data` and `real_validation_data` must have ' - 'the same columns and must match the columns described in the metadata.' - ) - if minority_class_label not in real_training_data[prediction_column_name].unique(): raise ValueError( f'The value `{minority_class_label}` is not present in the column ' From 6a29e27e234a1e2ff2739dc35663e0e680d6d5e9 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 20 Feb 2025 18:18:59 +0000 Subject: [PATCH 03/14] unit tests --- .../data_augmentation/test_base.py | 11 +- .../data_augmentation/test_utils.py | 14 - tests/unit/test__utils_metadata.py | 252 ++++++++++++++++++ 3 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 tests/unit/test__utils_metadata.py diff --git a/tests/unit/single_table/data_augmentation/test_base.py b/tests/unit/single_table/data_augmentation/test_base.py index 0e919408..2ac43757 100644 --- a/tests/unit/single_table/data_augmentation/test_base.py +++ b/tests/unit/single_table/data_augmentation/test_base.py @@ -194,7 +194,7 @@ def test__fit(self, real_training_data, metadata): discrete_columns, datetime_columns = metric._fit(real_training_data, metadata, 'target') # Assert - assert discrete_columns == ['categorical', 'boolean'] + assert discrete_columns == ['boolean', 'categorical'] assert datetime_columns == ['datetime'] def test__transform(self, real_training_data, synthetic_data, real_validation_data): @@ -282,6 +282,9 @@ def test__fit_transform( for table_name, table in transformed.items(): assert table.equals(tables[table_name]) + @patch( + 'sdmetrics.single_table.data_augmentation.base._process_data_with_metadata_ml_efficacy_metrics' + ) @patch('sdmetrics.single_table.data_augmentation.base._validate_inputs') @patch( 'sdmetrics.single_table.data_augmentation.base.BaseDataAugmentationMetric._fit_transform' @@ -295,6 +298,7 @@ def test_compute_breakdown( mock_classifier_trainer, mock_fit_transfrom, mock_validate_inputs, + mock_process_data_with_metadata, real_training_data, synthetic_data, real_validation_data, @@ -306,7 +310,7 @@ def test_compute_breakdown( minority_class_label = 1 classifier = 'XGBoost' fixed_recall_value = 0.9 - + mock_process_data_with_metadata.side_effect = lambda x, y, z, t: (x, y, z) real_data_baseline = { 'precision_score_training': 0.43, 'recall_score_validation': 0.7, @@ -378,6 +382,9 @@ def test_compute_breakdown( classifier, fixed_recall_value, ) + mock_process_data_with_metadata.assert_called_once_with( + real_training_data, synthetic_data, real_validation_data, metadata + ) mock_fit_transfrom.assert_called_once_with( real_training_data, synthetic_data, diff --git a/tests/unit/single_table/data_augmentation/test_utils.py b/tests/unit/single_table/data_augmentation/test_utils.py index 48b05757..a0026d26 100644 --- a/tests/unit/single_table/data_augmentation/test_utils.py +++ b/tests/unit/single_table/data_augmentation/test_utils.py @@ -99,10 +99,6 @@ def test__validate_data_and_metadata(): expected_message_sdtype = re.escape( 'The column `target` must be either categorical or boolean. Please update your metadata.' ) - expected_message_column_missmatch = re.escape( - '`real_training_data`, `synthetic_data` and `real_validation_data` must have the ' - 'same columns and must match the columns described in the metadata.' - ) expected_message_value = re.escape( 'The value `1` is not present in the column `target` for the real training data.' ) @@ -129,16 +125,6 @@ def test__validate_data_and_metadata(): with pytest.raises(ValueError, match=expected_message_sdtype): _validate_data_and_metadata(**wrong_inputs_sdtype) - wrong_column_metadata = deepcopy(inputs) - wrong_column_metadata['metadata']['columns'].update({'new_column': {'sdtype': 'categorical'}}) - with pytest.raises(ValueError, match=expected_message_column_missmatch): - _validate_data_and_metadata(**wrong_column_metadata) - - wrong_column_data = deepcopy(inputs) - wrong_column_data['real_training_data'] = pd.DataFrame({'new_column': [1, 0, 0]}) - with pytest.raises(ValueError, match=expected_message_column_missmatch): - _validate_data_and_metadata(**wrong_column_data) - missing_minority_class_label = deepcopy(inputs) missing_minority_class_label['real_training_data'] = pd.DataFrame({'target': [0, 0, 0]}) with pytest.raises(ValueError, match=expected_message_value): diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py new file mode 100644 index 00000000..9397e056 --- /dev/null +++ b/tests/unit/test__utils_metadata.py @@ -0,0 +1,252 @@ +import re +from copy import deepcopy +from unittest.mock import call, patch + +import pandas as pd +import pytest + +from sdmetrics._utils_metadata import ( + _convert_datetime_columns, + _process_data_with_metadata, + _process_data_with_metadata_ml_efficacy_metrics, + _remove_missing_columns_metadata, + _remove_non_modelable_columns, + _validate_metadata_dict, +) + + +@pytest.fixture +def data(): + return { + 'table1': pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_str': ['2021-01-01', '2021-01-02', '2021-01-03'], + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + }), + 'table2': pd.DataFrame({ + 'datetime_missing_format': ['2024-01-01', '2023-01-02', '2024-01-03'], + 'extra_column_1': [1, 2, 3], + 'extra_column_2': ['a', 'b', 'c'], + }), + } + + +@pytest.fixture +def metadata(): + return { + 'tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime_str': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'datetime': {'sdtype': 'datetime'}, + } + }, + 'table2': { + 'columns': { + 'datetime_missing_format': {'sdtype': 'datetime'}, + } + }, + } + } + + +def test__validate_metadata_dict(metadata): + """Test the ``_validate_metadata_dict`` method.""" + # Setup + metadata_wrong = 'wrong' + expected_error = re.escape( + f"Expected a dictionary but received a '{type(metadata_wrong).__name__}' instead." + " For SDV metadata objects, please use the 'to_dict' function to convert it" + ' to a dictionary.' + ) + + # Run and Assert + _validate_metadata_dict(metadata) + with pytest.raises(TypeError, match=expected_error): + _validate_metadata_dict(metadata_wrong) + + +def test__convert_datetime_columns(data, metadata): + """Test the ``_convert_datetime_columns`` method.""" + # Setup + expected_df_single_table = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_str': pd.to_datetime(['2021-01-01', '2021-01-02', '2021-01-03']), + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + }) + expected_result_multi_table_table = { + 'table1': expected_df_single_table, + 'table2': data['table2'], + } + + # Run + result_multi_table = _convert_datetime_columns(data, metadata) + result_single_table = _convert_datetime_columns(data['table1'], metadata['tables']['table1']) + + # Assert + for table_name, table in result_multi_table.items(): + pd.testing.assert_frame_equal(table, expected_result_multi_table_table[table_name]) + + pd.testing.assert_frame_equal(result_single_table, expected_df_single_table) + + +def test__remove_missing_columns_metadata(data, metadata): + """Test the ``_remove_missing_columns_metadata`` method.""" + # Setup + expected_warning_missing_column_metadata = re.escape( + "Some columns ('extra_column_1', 'extra_column_2') are not present in the metadata." + 'They will not be included for further evaluation.' + ) + expected_warning_extra_metadata_column = re.escape( + "Some columns ('numerical') are in the metadata but they are not present in the data." + ) + data['table1'] = data['table1'].drop(columns=['numerical']) + + # Run + with pytest.warns(UserWarning, match=expected_warning_extra_metadata_column): + _remove_missing_columns_metadata(data['table1'], metadata['tables']['table1']) + + with pytest.warns(UserWarning, match=expected_warning_missing_column_metadata): + result = _remove_missing_columns_metadata(data['table2'], metadata['tables']['table2']) + + # Assert + pd.testing.assert_frame_equal( + result, data['table2'].drop(columns=['extra_column_1', 'extra_column_2']) + ) + + +def test__remove_missing_columns_metadata_with_single_table(data, metadata): + """Test the ``_remove_missing_columns_metadata`` method with a single table.""" + # Setup + expected_df_single_table = data['table2'].drop(columns=['extra_column_1', 'extra_column_2']) + expected_df_multi_table = { + 'table1': data['table1'], + 'table2': expected_df_single_table, + } + + # Run + result_single_table = _remove_missing_columns_metadata( + data['table2'], metadata['tables']['table2'] + ) + result_multi_table = _remove_missing_columns_metadata(data, metadata) + + # Assert + pd.testing.assert_frame_equal(result_single_table, expected_df_single_table) + for table_name, table in result_multi_table.items(): + pd.testing.assert_frame_equal(table, expected_df_multi_table[table_name]) + + +def test__remove_non_modelable_columns(data, metadata): + """Test the ``_remove_non_modelable_columns`` method.""" + # Setup + single_table_df = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'id': [1, 2, 3], + 'boolean': [True, False, True], + 'datetime': pd.to_datetime(['2025-01-01', '2025-01-02', '2025-01-03']), + 'ssn': ['123-45-6789', '987-65-4321', '123-45-6789'], + }) + metadata_single_table = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'id': {'sdtype': 'id'}, + 'boolean': {'sdtype': 'boolean'}, + 'datetime': {'sdtype': 'datetime'}, + 'ssn': {'sdtype': 'ssn'}, + } + } + multi_table = { + 'table1': deepcopy(single_table_df), + 'table2': data['table2'], + } + multi_table_metadata = { + 'tables': { + 'table1': metadata_single_table, + 'table2': metadata['tables']['table2'], + } + } + + # Run + result_single_table = _remove_non_modelable_columns(single_table_df, metadata_single_table) + result_multi_table = _remove_non_modelable_columns(multi_table, multi_table_metadata) + + # Assert + pd.testing.assert_frame_equal(result_single_table, single_table_df.drop(columns=['id', 'ssn'])) + pd.testing.assert_frame_equal( + result_multi_table['table1'], single_table_df.drop(columns=['id', 'ssn']) + ) + + +@patch('sdmetrics._utils_metadata._validate_metadata_dict') +@patch('sdmetrics._utils_metadata._remove_missing_columns_metadata') +@patch('sdmetrics._utils_metadata._convert_datetime_columns') +@patch('sdmetrics._utils_metadata._remove_non_modelable_columns') +def test__process_data_with_metadata( + mock_remove_non_modelable_columns, + mock_convert_datetime_columns, + mock_remove_missing_columns_metadata, + mock_validate_metadata_dict, + data, + metadata, +): + """Test the ``_process_data_with_metadata``method.""" + # Setup + mock_convert_datetime_columns.side_effect = lambda data, metadata: data + mock_remove_missing_columns_metadata.side_effect = lambda data, metadata: data + + # Run and Assert + _process_data_with_metadata(data, metadata) + + mock_convert_datetime_columns.assert_called_once_with(data, metadata) + mock_remove_missing_columns_metadata.assert_called_once_with(data, metadata) + mock_validate_metadata_dict.assert_called_once_with(metadata) + mock_remove_non_modelable_columns.assert_not_called() + + _process_data_with_metadata(data, metadata, keep_modelable_columns_only=True) + mock_remove_non_modelable_columns.assert_called_once_with(data, metadata) + + +@patch('sdmetrics._utils_metadata._process_data_with_metadata') +def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_metadata): + """Test the ``_process_data_with_metadata_ml_efficacy_metrics`` method.""" + # Setup + mock_process_data_with_metadata.side_effect = lambda data, metadata: data + real_training_data = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + }) + synthetic_data = pd.DataFrame({ + 'numerical': [4, 5, 6], + 'categorical': ['a', 'b', 'c'], + }) + real_validation_data = pd.DataFrame({ + 'numerical': [7, 8, 9], + 'categorical': ['a', 'b', 'c'], + }) + metadata = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + } + + # Run + result = _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata + ) + + # Assert + pd.testing.assert_frame_equal(result[0], real_training_data) + pd.testing.assert_frame_equal(result[1], synthetic_data) + pd.testing.assert_frame_equal(result[2], real_validation_data) + mock_process_data_with_metadata.assert_has_calls([ + call(real_training_data, metadata), + call(synthetic_data, metadata), + call(real_validation_data, metadata), + ]) From 6978cc90d6a5444bb6fa40a76fe567320e077f1f Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 20 Feb 2025 18:19:10 +0000 Subject: [PATCH 04/14] integration tests --- .../test_binary_classifier_recall_efficacy.py | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index fd9e6843..713e7320 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -124,7 +124,7 @@ def test_with_nan_target_column(self): ) # Assert - assert result_breakdown['score'] in (0.5, 0.5384615384615385) + assert result_breakdown['score'] in (0.6, 0.6538461538461539) def test_with_minority_being_majority(self): """Test the metric when the minority class is the majority class.""" @@ -148,7 +148,7 @@ def test_with_minority_being_majority(self): ) # Assert - assert score == 0.46153846153846156 + assert score == 0.5 def test_with_multi_class(self): """Test the metric with multi-class classification. @@ -176,3 +176,40 @@ def test_with_multi_class(self): # Assert assert score_breakdown['score'] in (0.46153846153846156, 0.5384615384615384) + + def test_speical_data_metadata_config(self): + """Test the metric with a special data and metadata configuration. + + in this test: + - The `start_date` column is an object datetime column. + - The data has an extra column compared to the metadata (`extra_column`). + - The metadata has an extra column compared to the data (`extra_metadata_column`). + """ + # Setup + np.random.seed(0) + real_data, synthetic_data, metadata = load_demo(modality='single_table') + real_data['start_date'] = real_data['start_date'].astype(str) + synthetic_data['start_date'] = synthetic_data['start_date'].astype(str) + mask_validation = np.random.rand(len(real_data)) < 0.8 + real_training = real_data[mask_validation] + real_validation = real_data[~mask_validation] + warning_datetime = re.escape( + 'No `datetime_format` provided in the metadata when trying to convert the columns' + " 'start_date' to datetime. The format will be inferred, but it may not be accurate." + ) + + # Run + with pytest.warns(UserWarning, match=warning_datetime): + score = BinaryClassifierRecallEfficacy.compute( + real_training_data=real_training, + synthetic_data=synthetic_data, + real_validation_data=real_validation, + metadata=metadata, + prediction_column_name='gender', + minority_class_label='F', + classifier='XGBoost', + fixed_precision_value=0.8, + ) + + # Assert + assert score == 0.5 From 1c3f2a0a1f0fe7efba3630552272bd89f88504d3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 21 Feb 2025 12:40:47 +0000 Subject: [PATCH 05/14] fix tests --- sdmetrics/_utils_metadata.py | 2 +- ...st_binary_classifier_precision_efficacy.py | 48 +++++++++---------- .../test_binary_classifier_recall_efficacy.py | 4 +- tests/unit/test__utils_metadata.py | 8 ++-- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 4beabc67..e0165d7b 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -6,7 +6,7 @@ def _validate_metadata_dict(metadata): - """Validate the metadata of the Data Augmentation metrics.""" + """Validate the metadata type.""" if not isinstance(metadata, dict): raise TypeError( f"Expected a dictionary but received a '{type(metadata).__name__}' instead." diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py index aceedd1b..7459ee4a 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_precision_efficacy.py @@ -44,13 +44,13 @@ def test_end_to_end(self): expected_score_breakdown = { 'real_data_baseline': { 'recall_score_training': 0.8095238095238095, - 'recall_score_validation': 0.07692307692307693, - 'precision_score_validation': 1.0, + 'recall_score_validation': 0.15384615384615385, + 'precision_score_validation': 0.4, 'prediction_counts_validation': { - 'true_positive': 1, - 'false_positive': 0, - 'true_negative': 25, - 'false_negative': 12, + 'true_positive': 2, + 'false_positive': 3, + 'true_negative': 22, + 'false_negative': 11, }, }, 'augmented_data': { @@ -59,8 +59,8 @@ def test_end_to_end(self): 'precision_score_validation': 0.0, 'prediction_counts_validation': { 'true_positive': 0, - 'false_positive': 2, - 'true_negative': 23, + 'false_positive': 1, + 'true_negative': 24, 'false_negative': 13, }, }, @@ -70,7 +70,7 @@ def test_end_to_end(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0, + 'score': 0.3, } assert np.isclose( score_breakdown['real_data_baseline']['recall_score_training'], 0.8, atol=0.1 @@ -147,13 +147,13 @@ def test_with_nan_target_column(self): }, 'augmented_data': { 'recall_score_training': 0.8, - 'recall_score_validation': 0.15384615384615385, - 'precision_score_validation': 0.4, + 'recall_score_validation': 0.23076923076923078, + 'precision_score_validation': 0.6, 'prediction_counts_validation': { - 'true_positive': 2, - 'false_positive': 3, - 'true_negative': 30, - 'false_negative': 11, + 'true_positive': 3, + 'false_positive': 2, + 'true_negative': 31, + 'false_negative': 10, }, }, 'parameters': { @@ -162,7 +162,7 @@ def test_with_nan_target_column(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0.48571428571428577, + 'score': 0.5857142857142857, } assert result_breakdown == expected_result @@ -188,7 +188,7 @@ def test_with_minority_being_majority(self): ) # Assert - assert score == 0 + assert score == 0.3 def test_with_multi_class(self): """Test the metric with multi-class classification. @@ -229,13 +229,13 @@ def test_with_multi_class(self): }, 'augmented_data': { 'recall_score_training': 0.8035714285714286, - 'recall_score_validation': 0.6153846153846154, - 'precision_score_validation': 0.8888888888888888, + 'recall_score_validation': 0.46153846153846156, + 'precision_score_validation': 1.0, 'prediction_counts_validation': { - 'true_positive': 8, - 'false_positive': 1, - 'true_negative': 24, - 'false_negative': 5, + 'true_positive': 6, + 'false_positive': 0, + 'true_negative': 25, + 'false_negative': 7, }, }, 'parameters': { @@ -244,6 +244,6 @@ def test_with_multi_class(self): 'classifier': 'XGBoost', 'fixed_recall_value': 0.8, }, - 'score': 0.4944444444444444, + 'score': 0.55, } assert score_breakdown == expected_score_breakdown diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index 713e7320..76a72919 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -182,12 +182,14 @@ def test_speical_data_metadata_config(self): in this test: - The `start_date` column is an object datetime column. - - The data has an extra column compared to the metadata (`extra_column`). + - The synthetic data has an extra column compared to the metadata (`extra_column`). - The metadata has an extra column compared to the data (`extra_metadata_column`). """ # Setup np.random.seed(0) real_data, synthetic_data, metadata = load_demo(modality='single_table') + metadata['columns']['extra_metadata_column'] = {'sdtype': 'categorical'} + synthetic_data['extra_column'] = 'extra' real_data['start_date'] = real_data['start_date'].astype(str) synthetic_data['start_date'] = synthetic_data['start_date'].astype(str) mask_validation = np.random.rand(len(real_data)) < 0.8 diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index 9397e056..d2d1dfaf 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -216,7 +216,7 @@ def test__process_data_with_metadata( def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_metadata): """Test the ``_process_data_with_metadata_ml_efficacy_metrics`` method.""" # Setup - mock_process_data_with_metadata.side_effect = lambda data, metadata: data + mock_process_data_with_metadata.side_effect = lambda data, metadata, x: data real_training_data = pd.DataFrame({ 'numerical': [1, 2, 3], 'categorical': ['a', 'b', 'c'], @@ -246,7 +246,7 @@ def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_ pd.testing.assert_frame_equal(result[1], synthetic_data) pd.testing.assert_frame_equal(result[2], real_validation_data) mock_process_data_with_metadata.assert_has_calls([ - call(real_training_data, metadata), - call(synthetic_data, metadata), - call(real_validation_data, metadata), + call(real_training_data, metadata, True), + call(synthetic_data, metadata, True), + call(real_validation_data, metadata, True), ]) From f1bccb92e2e38c0e0cb3b6b4a71e22207e60db63 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 21 Feb 2025 13:47:50 +0000 Subject: [PATCH 06/14] propagate changes 1 --- sdmetrics/_utils_metadata.py | 31 ++++++++++++ .../multi_table/detection/parent_child.py | 5 +- sdmetrics/multi_table/multi_single_table.py | 4 +- sdmetrics/reports/base_report.py | 21 +-------- .../multi_table/base_multi_table_report.py | 16 ------- sdmetrics/single_table/base.py | 7 ++- .../single_table/data_augmentation/utils.py | 14 ++---- .../test_cardinality_statistic_similarity.py | 12 ++--- .../test_base_multi_table_report.py | 47 ------------------- tests/unit/reports/test_base_report.py | 36 ++------------ 10 files changed, 52 insertions(+), 141 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index e0165d7b..8f9ebe95 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -15,6 +15,37 @@ def _validate_metadata_dict(metadata): ) +def _validate_single_table_metadata(metadata): + """Validate the metadata for a single table.""" + _validate_metadata_dict(metadata) + if 'columns' not in metadata: + raise ValueError( + 'Single table reports expect metadata to contain a "columns" key with a mapping' + ' from column names to column informations.' + ) + + +def _validate_multi_table_metadata(metadata): + """Validate the metadata for multiple tables.""" + _validate_metadata_dict(metadata) + if 'tables' not in metadata: + raise ValueError( + 'Multi table reports expect metadata to contain a "tables" key with a mapping' + ' from table names to metadata for each table.' + ) + for table_name, table_metadata in metadata['tables'].items(): + _validate_single_table_metadata(table_metadata) + + +def _validate_metadata(metadata): + """Validate the metadata.""" + _validate_metadata_dict(metadata) + if 'tables' in metadata: + _validate_multi_table_metadata(metadata) + else: + _validate_single_table_metadata(metadata) + + def handle_single_and_multi_table(single_table_func): """Decorator to handle both single and multi table functions.""" diff --git a/sdmetrics/multi_table/detection/parent_child.py b/sdmetrics/multi_table/detection/parent_child.py index b2f4a776..c423229c 100644 --- a/sdmetrics/multi_table/detection/parent_child.py +++ b/sdmetrics/multi_table/detection/parent_child.py @@ -2,6 +2,7 @@ import numpy as np +from sdmetrics._utils_metadata import _validate_multi_table_metadata from sdmetrics.multi_table.detection.base import DetectionMetric from sdmetrics.single_table.detection import LogisticDetection, SVCDetection from sdmetrics.utils import get_columns_from_metadata, nested_attrs_meta @@ -37,9 +38,7 @@ class ParentChildDetectionMetric( @staticmethod def _extract_foreign_keys(metadata): - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_multi_table_metadata(metadata) foreign_keys = [] for child_table, child_meta in metadata['tables'].items(): for child_key, field_meta in get_columns_from_metadata(child_meta).items(): diff --git a/sdmetrics/multi_table/multi_single_table.py b/sdmetrics/multi_table/multi_single_table.py index 032d7c42..2e34fc28 100644 --- a/sdmetrics/multi_table/multi_single_table.py +++ b/sdmetrics/multi_table/multi_single_table.py @@ -6,6 +6,7 @@ import numpy as np from sdmetrics import single_table +from sdmetrics._utils_metadata import _validate_multi_table_metadata from sdmetrics.errors import IncomputableMetricError from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import nested_attrs_meta @@ -77,9 +78,8 @@ def _compute(self, real_data, synthetic_data, metadata=None, **kwargs): if metadata is None: metadata = {'tables': defaultdict(type(None))} - elif not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_multi_table_metadata(metadata) scores = {} errors = {} for table_name, real_table in real_data.items(): diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 614b8083..8fe6002d 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -13,6 +13,7 @@ import pandas as pd import tqdm +from sdmetrics._utils_metadata import _validate_metadata from sdmetrics.reports.utils import convert_datetime_columns from sdmetrics.visualization import set_plotly_config @@ -68,17 +69,6 @@ def _validate_data_format(self, real_data, synthetic_data): ) raise ValueError(error_message) - def _validate_metadata_format(self, metadata): - """Validate the metadata.""" - if not isinstance(metadata, dict): - raise TypeError('The provided metadata is not a dictionary.') - - if 'columns' not in metadata: - raise ValueError( - 'Single table reports expect metadata to contain a "columns" key with a mapping' - ' from column names to column informations.' - ) - def _validate(self, real_data, synthetic_data, metadata): """Validate the inputs. @@ -91,7 +81,7 @@ def _validate(self, real_data, synthetic_data, metadata): The metadata of the table. """ self._validate_data_format(real_data, synthetic_data) - self._validate_metadata_format(metadata) + _validate_metadata(metadata) self._validate_metadata_matches_data(real_data, synthetic_data, metadata) @staticmethod @@ -142,13 +132,6 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True): verbose (bool): Whether or not to print report summary and progress. """ - if not isinstance(metadata, dict): - raise TypeError( - f"Expected a dictionary but received a '{type(metadata).__name__}' instead." - " For SDV metadata objects, please use the 'to_dict' function to convert it" - ' to a dictionary.' - ) - self._validate(real_data, synthetic_data, metadata) self.convert_datetimes(real_data, synthetic_data, metadata) diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index 1b78aa07..319c182f 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -40,22 +40,6 @@ def _validate_data_format(self, real_data, synthetic_data): raise ValueError(error_message) - def _validate_metadata_format(self, metadata): - """Validate the metadata.""" - if not isinstance(metadata, dict): - raise TypeError('The provided metadata is not a dictionary.') - - if 'tables' not in metadata: - raise ValueError( - 'Multi table reports expect metadata to contain a "tables" key with a mapping' - ' from table names to metadata for each table.' - ) - for table_name, table_metadata in metadata['tables'].items(): - if 'columns' not in table_metadata: - raise ValueError( - f'The metadata for table "{table_name}" is missing a "columns" key.' - ) - def _validate_relationships(self, real_data, synthetic_data, metadata): """Validate that the relationships are valid.""" for rel in metadata.get('relationships', []): diff --git a/sdmetrics/single_table/base.py b/sdmetrics/single_table/base.py index 32877b06..25257adc 100644 --- a/sdmetrics/single_table/base.py +++ b/sdmetrics/single_table/base.py @@ -5,6 +5,7 @@ import pandas as pd +from sdmetrics._utils_metadata import _validate_single_table_metadata from sdmetrics.base import BaseMetric from sdmetrics.errors import IncomputableMetricError from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta @@ -92,7 +93,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None): The real data. synthetic_data(pandas.DataFrame): The synthetic data. - metadata (dict or Metadata or None): + metadata (dict): The metadata, if any. Returns: @@ -108,9 +109,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None): raise ValueError('`real_data` and `synthetic_data` must have the same columns') if metadata is not None: - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_single_table_metadata(metadata) fields = get_columns_from_metadata(metadata) for column in real_data.columns: if column not in fields: diff --git a/sdmetrics/single_table/data_augmentation/utils.py b/sdmetrics/single_table/data_augmentation/utils.py index e015403e..e98b8fcd 100644 --- a/sdmetrics/single_table/data_augmentation/utils.py +++ b/sdmetrics/single_table/data_augmentation/utils.py @@ -2,6 +2,8 @@ import pandas as pd +from sdmetrics._utils_metadata import _validate_single_table_metadata + def _validate_tables(real_training_data, synthetic_data, real_validation_data): """Validate the tables of the Data Augmentation metrics.""" @@ -13,16 +15,6 @@ def _validate_tables(real_training_data, synthetic_data, real_validation_data): ) -def _validate_metadata(metadata): - """Validate the metadata of the Data Augmentation metrics.""" - if not isinstance(metadata, dict): - raise TypeError( - f"Expected a dictionary but received a '{type(metadata).__name__}' instead." - " For SDV metadata objects, please use the 'to_dict' function to convert it" - ' to a dictionary.' - ) - - def _validate_prediction_column_name(prediction_column_name): """Validate the prediction column name of the Data Augmentation metrics.""" if not isinstance(prediction_column_name, str): @@ -55,7 +47,7 @@ def _validate_parameters( ): """Validate the parameters of the Data Augmentation metrics.""" _validate_tables(real_training_data, synthetic_data, real_validation_data) - _validate_metadata(metadata) + _validate_single_table_metadata(metadata) _validate_prediction_column_name(prediction_column_name) _validate_classifier(classifier) _validate_fixed_recall_value(fixed_recall_value) diff --git a/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py b/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py index 74804500..f1ece8dc 100644 --- a/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py +++ b/tests/unit/multi_table/statistical/test_cardinality_statistic_similarity.py @@ -158,14 +158,14 @@ def test_compute_breakdown(self): # Setup metadata = { 'tables': { - 'tableA': {'fields': {'col1': {}}}, + 'tableA': {'columns': {'col1': {}}}, 'tableB': { - 'fields': { + 'columns': { 'col1': {}, 'col2': {}, }, }, - 'tableC': {'fields': {'col2': {}}}, + 'tableC': {'columns': {'col2': {}}}, }, 'relationships': [ { @@ -230,9 +230,9 @@ def test_compute_breakdown_no_relationships(self): # Setup metadata = { 'tables': { - 'tableA': {'fields': {'col1': {}}}, - 'tableB': {'fields': {'col1': {}, 'col2': {}}}, - 'tableC': {'fields': {'col2': {}}}, + 'tableA': {'columns': {'col1': {}}}, + 'tableB': {'columns': {'col1': {}, 'col2': {}}}, + 'tableC': {'columns': {'col2': {}}}, }, } real_data = { diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index f5b50c9b..b2f308e0 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -41,53 +41,6 @@ def test__validate_data_format(self): with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) - def test__validate_metadata_format(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata is not a dictionnary. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = [] - - # Run and Assert - expected_message = 'The provided metadata is not a dictionary.' - with pytest.raises(TypeError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_with_no_tables(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'tables' key. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = {} - - # Run and Assert - expected_message = ( - 'Multi table reports expect metadata to contain a "tables" key with a mapping from ' - 'table names to metadata for each table.' - ) - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_with_no_columns(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'columns' key. - """ - # Setup - base_report = BaseMultiTableReport() - metadata = {'tables': {'Table_1': {}}} - - # Run and Assert - expected_message = 'The metadata for table "Table_1" is missing a "columns" key.' - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - def test__validate_relationships(self): """Test the ``_validate_relationships`` method.""" # Setup diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 3a82d7f9..4e72f54e 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -31,38 +31,6 @@ def test__validate_data_format(self): with pytest.raises(ValueError, match=expected_message): base_report._validate_data_format(real_data, synthetic_data) - def test__validate_metadata_format(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata is not a dictionary. - """ - # Setup - base_report = BaseReport() - metadata = 'metadata' - - # Run and Assert - expected_message = 'The provided metadata is not a dictionary.' - with pytest.raises(TypeError, match=expected_message): - base_report._validate_metadata_format(metadata) - - def test__validate_metadata_format_no_columns(self): - """Test the ``_validate_metadata_format`` method. - - This test checks that the method raises an error when the metadata does not contain a - 'columns' key. - """ - # Setup - base_report = BaseReport() - metadata = {} - - # Run and Assert - expected_message = ( - 'Single table reports expect metadata to contain a "columns" key with a mapping' - ' from column names to column informations.' - ) - with pytest.raises(ValueError, match=expected_message): - base_report._validate_metadata_format(metadata) - def test__validate_metadata_matches_data(self): """Test the ``_validate_metadata_matches_data`` method. @@ -137,7 +105,8 @@ def test__validate_metadata_matches_data_no_mismatch(self): # Run and Assert base_report._validate_metadata_matches_data(real_data, synthetic_data, metadata) - def test__validate(self): + @patch('sdmetrics.reports.base_report._validate_metadata') + def test__validate(self, mock__validate_metadata): """Test the ``_validate`` method.""" # Setup base_report = BaseReport() @@ -165,6 +134,7 @@ def test__validate(self): base_report._validate(real_data, synthetic_data, metadata) # Assert + mock__validate_metadata.assert_called_once_with(metadata) mock__validate_metadata_matches_data.assert_called_once_with( real_data, synthetic_data, metadata ) From 14ea180d186eeb233e82b3266265b4ea0e0139a3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 21 Feb 2025 14:39:15 +0000 Subject: [PATCH 07/14] update err message --- sdmetrics/_utils_metadata.py | 16 +++-- .../cardinality_shape_similarity.py | 4 +- .../cardinality_statistic_similarity.py | 4 +- sdmetrics/timeseries/base.py | 5 +- tests/unit/test__utils_metadata.py | 65 +++++++++++++++++++ 5 files changed, 81 insertions(+), 13 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 8f9ebe95..6652c615 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -20,8 +20,8 @@ def _validate_single_table_metadata(metadata): _validate_metadata_dict(metadata) if 'columns' not in metadata: raise ValueError( - 'Single table reports expect metadata to contain a "columns" key with a mapping' - ' from column names to column informations.' + "Single-table metadata must include a 'columns' key that maps column names" + ' to their corresponding information.' ) @@ -30,8 +30,8 @@ def _validate_multi_table_metadata(metadata): _validate_metadata_dict(metadata) if 'tables' not in metadata: raise ValueError( - 'Multi table reports expect metadata to contain a "tables" key with a mapping' - ' from table names to metadata for each table.' + "Multi-table metadata must include a 'tables' key that maps table names" + ' to their respective metadata.' ) for table_name, table_metadata in metadata['tables'].items(): _validate_single_table_metadata(table_metadata) @@ -40,10 +40,14 @@ def _validate_multi_table_metadata(metadata): def _validate_metadata(metadata): """Validate the metadata.""" _validate_metadata_dict(metadata) + if ('columns' not in metadata) and ('tables' not in metadata): + raise ValueError( + "Metadata must include either a 'columns' key for single-table metadata" + " or a 'tables' key for multi-table metadata." + ) + if 'tables' in metadata: _validate_multi_table_metadata(metadata) - else: - _validate_single_table_metadata(metadata) def handle_single_and_multi_table(single_table_func): diff --git a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py index 54c16cad..6ce0cc14 100644 --- a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py @@ -6,6 +6,7 @@ from sdmetrics.goal import Goal from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution +from sdmetrics._utils_metadata import _validate_metadata, _validate_metadata_dict, _validate_single_table_metadata class CardinalityShapeSimilarity(MultiTableMetric): @@ -53,9 +54,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata): """ if set(real_data.keys()) != set(synthetic_data.keys()): raise ValueError('`real_data` and `synthetic_data` must have the same tables.') - if not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_metadata_dict(metadata) score_breakdowns = {} for rel in metadata.get('relationships', []): cardinality_real = get_cardinality_distribution( diff --git a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py index 7accfbde..789f5c1f 100644 --- a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py @@ -8,6 +8,7 @@ from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution from sdmetrics.warnings import ConstantInputWarning +from sdmetrics._utils_metadata import _validate_metadata_dict class CardinalityStatisticSimilarity(MultiTableMetric): @@ -107,9 +108,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata=None, statistic=' raise ValueError('`real_data` and `synthetic_data` must have the same tables.') if metadata is None: raise ValueError('`metadata` cannot be ``None``.') - if not isinstance(metadata, dict): - metadata = metadata.to_dict() + _validate_metadata_dict(metadata) score_breakdowns = {} for rel in metadata.get('relationships', []): cardinality_real = get_cardinality_distribution( diff --git a/sdmetrics/timeseries/base.py b/sdmetrics/timeseries/base.py index 95c4fbce..24ce86b7 100644 --- a/sdmetrics/timeseries/base.py +++ b/sdmetrics/timeseries/base.py @@ -6,6 +6,7 @@ from sdmetrics.base import BaseMetric from sdmetrics.utils import get_columns_from_metadata +from sdmetrics._utils_metadata import _validate_metadata_dict class TimeSeriesMetric(BaseMetric): @@ -51,9 +52,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key raise ValueError('`real_data` and `synthetic_data` must have the same columns') if metadata is not None: - if not isinstance(metadata, dict): - metadata = metadata.to_dict() - + _validate_metadata_dict(metadata) fields = get_columns_from_metadata(metadata) for column in real_data.columns: if column not in fields: diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index d2d1dfaf..324a0d69 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -11,7 +11,10 @@ _process_data_with_metadata_ml_efficacy_metrics, _remove_missing_columns_metadata, _remove_non_modelable_columns, + _validate_metadata, _validate_metadata_dict, + _validate_multi_table_metadata, + _validate_single_table_metadata, ) @@ -69,6 +72,68 @@ def test__validate_metadata_dict(metadata): _validate_metadata_dict(metadata_wrong) +def test__validate_single_table_metadata(metadata): + """Test the ``_validate_single_table_metadata`` method.""" + # Setup + metadata_wrong = { + 'wrong_key': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + } + expected_error = re.escape( + "Single-table metadata must include a 'columns' key that maps column names" + ' to their corresponding information.' + ) + + # Run and Assert + _validate_single_table_metadata(metadata['tables']['table1']) + with pytest.raises(ValueError, match=expected_error): + _validate_single_table_metadata(metadata_wrong) + + +def test__validate_multi_table_metadata(metadata): + """Test the ``_validate_multi_table_metadata`` method.""" + # Setup + metadata_wrong = { + 'wrong_tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + } + } + expected_error = re.escape( + "Multi-table metadata must include a 'tables' key that maps table names to" + ' their respective metadata.' + ) + + # Run and Assert + _validate_multi_table_metadata(metadata) + with pytest.raises(ValueError, match=expected_error): + _validate_multi_table_metadata(metadata_wrong) + + +@patch('sdmetrics._utils_metadata._validate_multi_table_metadata') +def test__validate_metadata(mock_validate_multi_table_metadata, metadata): + """Test the ``_validate_metadata`` method.""" + # Setup + wrong_metadata = {'worng_key': 'wrong_value'} + expected_error = re.escape( + "Metadata must include either a 'columns' key for single-table metadata" + " or a 'tables' key for multi-table metadata." + ) + # Run + _validate_metadata(metadata) + with pytest.raises(ValueError, match=expected_error): + _validate_metadata(wrong_metadata) + + # Assert + mock_validate_multi_table_metadata.assert_called_once_with(metadata) + + def test__convert_datetime_columns(data, metadata): """Test the ``_convert_datetime_columns`` method.""" # Setup From 6793bf9f53ad59a9cf21cfbae9e28af58d486531 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 21 Feb 2025 14:43:47 +0000 Subject: [PATCH 08/14] lint + fix minimum version --- .../statistical/cardinality_shape_similarity.py | 4 +++- .../statistical/cardinality_statistic_similarity.py | 2 +- sdmetrics/timeseries/base.py | 2 +- .../test_binary_classifier_recall_efficacy.py | 8 ++++---- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py index 6ce0cc14..c4410248 100644 --- a/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_shape_similarity.py @@ -3,10 +3,12 @@ import numpy as np from scipy.stats import ks_2samp +from sdmetrics._utils_metadata import ( + _validate_metadata_dict, +) from sdmetrics.goal import Goal from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution -from sdmetrics._utils_metadata import _validate_metadata, _validate_metadata_dict, _validate_single_table_metadata class CardinalityShapeSimilarity(MultiTableMetric): diff --git a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py index 789f5c1f..a40126f2 100644 --- a/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py +++ b/sdmetrics/multi_table/statistical/cardinality_statistic_similarity.py @@ -4,11 +4,11 @@ import numpy as np +from sdmetrics._utils_metadata import _validate_metadata_dict from sdmetrics.goal import Goal from sdmetrics.multi_table.base import MultiTableMetric from sdmetrics.utils import get_cardinality_distribution from sdmetrics.warnings import ConstantInputWarning -from sdmetrics._utils_metadata import _validate_metadata_dict class CardinalityStatisticSimilarity(MultiTableMetric): diff --git a/sdmetrics/timeseries/base.py b/sdmetrics/timeseries/base.py index 24ce86b7..6c157458 100644 --- a/sdmetrics/timeseries/base.py +++ b/sdmetrics/timeseries/base.py @@ -4,9 +4,9 @@ import pandas as pd +from sdmetrics._utils_metadata import _validate_metadata_dict from sdmetrics.base import BaseMetric from sdmetrics.utils import get_columns_from_metadata -from sdmetrics._utils_metadata import _validate_metadata_dict class TimeSeriesMetric(BaseMetric): diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index 76a72919..66c40dee 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -124,7 +124,7 @@ def test_with_nan_target_column(self): ) # Assert - assert result_breakdown['score'] in (0.6, 0.6538461538461539) + assert result_breakdown['score'] in (0.3846153846153846, 0.6538461538461539) def test_with_minority_being_majority(self): """Test the metric when the minority class is the majority class.""" @@ -148,7 +148,7 @@ def test_with_minority_being_majority(self): ) # Assert - assert score == 0.5 + assert score in (0.5, 0.3846153846153846) def test_with_multi_class(self): """Test the metric with multi-class classification. @@ -175,7 +175,7 @@ def test_with_multi_class(self): ) # Assert - assert score_breakdown['score'] in (0.46153846153846156, 0.5384615384615384) + assert score_breakdown['score'] in (0.4230769230769231, 0.5384615384615384) def test_speical_data_metadata_config(self): """Test the metric with a special data and metadata configuration. @@ -214,4 +214,4 @@ def test_speical_data_metadata_config(self): ) # Assert - assert score == 0.5 + assert score in (0.5, 0.3846153846153846) From ccb4c2d590f44b96e882d877a5dabf832bc405cf Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 17:05:12 +0000 Subject: [PATCH 09/14] try/Except to add table name in the error message --- sdmetrics/_utils_metadata.py | 5 +++- .../single_table/test_bayesian_network.py | 10 ++++++-- tests/unit/test__utils_metadata.py | 24 +++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 6652c615..620f8ceb 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -34,7 +34,10 @@ def _validate_multi_table_metadata(metadata): ' to their respective metadata.' ) for table_name, table_metadata in metadata['tables'].items(): - _validate_single_table_metadata(table_metadata) + try: + _validate_single_table_metadata(table_metadata) + except ValueError as e: + raise ValueError(f"Error in table '{table_name}': {str(e)}") def _validate_metadata(metadata): diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 66daf369..8a321099 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -24,7 +24,10 @@ def test_compute(self, bad_pomegranate): metric = BNLikelihood() # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' + expected_message = ( + r'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' + r' Python 3.13 is not supported.' + ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) @@ -36,6 +39,9 @@ def test_compute(self, bad_pomegranate): metric = BNLogLikelihood() # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' + expected_message = ( + r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. ' + r'Python 3\.13 is not supported\.' + ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index 324a0d69..b8d03109 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -105,16 +105,40 @@ def test__validate_multi_table_metadata(metadata): }, } } + + metadata_wrong_single_table = { + 'tables': { + 'table1': { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + 'table2': { + 'wrong_key': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + }, + } + } expected_error = re.escape( "Multi-table metadata must include a 'tables' key that maps table names to" ' their respective metadata.' ) + expected_error_single_table = re.escape( + "Error in table 'table2': Single-table metadata must include a 'columns' key" + ' that maps column names to their corresponding information.' + ) # Run and Assert _validate_multi_table_metadata(metadata) with pytest.raises(ValueError, match=expected_error): _validate_multi_table_metadata(metadata_wrong) + with pytest.raises(ValueError, match=expected_error_single_table): + _validate_multi_table_metadata(metadata_wrong_single_table) + @patch('sdmetrics._utils_metadata._validate_multi_table_metadata') def test__validate_metadata(mock_validate_multi_table_metadata, metadata): From f64b9c8552db24784f42700f38ee62ae8f4c0282 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 17:18:40 +0000 Subject: [PATCH 10/14] move method + error message --- sdmetrics/_utils_metadata.py | 15 +----- .../single_table/data_augmentation/base.py | 6 ++- .../single_table/data_augmentation/utils.py | 13 ++++- .../data_augmentation/test_utils.py | 43 ++++++++++++++++- tests/unit/test__utils_metadata.py | 47 ++----------------- 5 files changed, 63 insertions(+), 61 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 620f8ceb..a1d79297 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -106,14 +106,14 @@ def _remove_missing_columns_metadata(data, metadata): if columns_to_remove: columns_to_print = "', '".join(sorted(columns_to_remove)) warnings.warn( - f"Some columns ('{columns_to_print}') are not present in the metadata." + f"The columns ('{columns_to_print}') are not present in the metadata." 'They will not be included for further evaluation.', UserWarning, ) elif extra_metadata_columns: columns_to_print = "', '".join(sorted(extra_metadata_columns)) warnings.warn( - f"Some columns ('{columns_to_print}') are in the metadata but they are not " + f"The columns ('{columns_to_print}') are in the metadata but they are not " 'present in the data.', UserWarning, ) @@ -148,14 +148,3 @@ def _process_data_with_metadata(data, metadata, keep_modelable_columns_only=Fals data = _remove_non_modelable_columns(data, metadata) return data - - -def _process_data_with_metadata_ml_efficacy_metrics( - real_training_data, synthetic_data, real_validation_data, metadata -): - """Process the data for ML efficacy metrics according to the metadata.""" - real_training_data = _process_data_with_metadata(real_training_data, metadata, True) - synthetic_data = _process_data_with_metadata(synthetic_data, metadata, True) - real_validation_data = _process_data_with_metadata(real_validation_data, metadata, True) - - return real_training_data, synthetic_data, real_validation_data diff --git a/sdmetrics/single_table/data_augmentation/base.py b/sdmetrics/single_table/data_augmentation/base.py index ebbb7b47..784a1112 100644 --- a/sdmetrics/single_table/data_augmentation/base.py +++ b/sdmetrics/single_table/data_augmentation/base.py @@ -7,10 +7,12 @@ from sklearn.metrics import confusion_matrix, precision_recall_curve, precision_score, recall_score from xgboost import XGBClassifier -from sdmetrics._utils_metadata import _process_data_with_metadata_ml_efficacy_metrics from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric -from sdmetrics.single_table.data_augmentation.utils import _validate_inputs +from sdmetrics.single_table.data_augmentation.utils import ( + _process_data_with_metadata_ml_efficacy_metrics, + _validate_inputs, +) METRIC_NAME_TO_METHOD = {'recall': recall_score, 'precision': precision_score} diff --git a/sdmetrics/single_table/data_augmentation/utils.py b/sdmetrics/single_table/data_augmentation/utils.py index e98b8fcd..e2a6c172 100644 --- a/sdmetrics/single_table/data_augmentation/utils.py +++ b/sdmetrics/single_table/data_augmentation/utils.py @@ -2,7 +2,7 @@ import pandas as pd -from sdmetrics._utils_metadata import _validate_single_table_metadata +from sdmetrics._utils_metadata import _process_data_with_metadata, _validate_single_table_metadata def _validate_tables(real_training_data, synthetic_data, real_validation_data): @@ -126,3 +126,14 @@ def _validate_inputs( prediction_column_name, minority_class_label, ) + + +def _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata +): + """Process the data for ML efficacy metrics according to the metadata.""" + real_training_data = _process_data_with_metadata(real_training_data, metadata, True) + synthetic_data = _process_data_with_metadata(synthetic_data, metadata, True) + real_validation_data = _process_data_with_metadata(real_validation_data, metadata, True) + + return real_training_data, synthetic_data, real_validation_data diff --git a/tests/unit/single_table/data_augmentation/test_utils.py b/tests/unit/single_table/data_augmentation/test_utils.py index a0026d26..3b018c93 100644 --- a/tests/unit/single_table/data_augmentation/test_utils.py +++ b/tests/unit/single_table/data_augmentation/test_utils.py @@ -1,11 +1,12 @@ import re from copy import deepcopy -from unittest.mock import patch +from unittest.mock import call, patch import pandas as pd import pytest from sdmetrics.single_table.data_augmentation.utils import ( + _process_data_with_metadata_ml_efficacy_metrics, _validate_data_and_metadata, _validate_inputs, _validate_parameters, @@ -187,3 +188,43 @@ def test__validate_inputs_mock(mock_validate_data_and_metadata, mock_validate_pa classifier, fixed_recall_value, ) + + +@patch('sdmetrics.single_table.data_augmentation.utils._process_data_with_metadata') +def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_metadata): + """Test the ``_process_data_with_metadata_ml_efficacy_metrics`` method.""" + # Setup + mock_process_data_with_metadata.side_effect = lambda data, metadata, x: data + real_training_data = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + }) + synthetic_data = pd.DataFrame({ + 'numerical': [4, 5, 6], + 'categorical': ['a', 'b', 'c'], + }) + real_validation_data = pd.DataFrame({ + 'numerical': [7, 8, 9], + 'categorical': ['a', 'b', 'c'], + }) + metadata = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + } + } + + # Run + result = _process_data_with_metadata_ml_efficacy_metrics( + real_training_data, synthetic_data, real_validation_data, metadata + ) + + # Assert + pd.testing.assert_frame_equal(result[0], real_training_data) + pd.testing.assert_frame_equal(result[1], synthetic_data) + pd.testing.assert_frame_equal(result[2], real_validation_data) + mock_process_data_with_metadata.assert_has_calls([ + call(real_training_data, metadata, True), + call(synthetic_data, metadata, True), + call(real_validation_data, metadata, True), + ]) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index b8d03109..a5b5578d 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -1,6 +1,6 @@ import re from copy import deepcopy -from unittest.mock import call, patch +from unittest.mock import patch import pandas as pd import pytest @@ -8,7 +8,6 @@ from sdmetrics._utils_metadata import ( _convert_datetime_columns, _process_data_with_metadata, - _process_data_with_metadata_ml_efficacy_metrics, _remove_missing_columns_metadata, _remove_non_modelable_columns, _validate_metadata, @@ -187,11 +186,11 @@ def test__remove_missing_columns_metadata(data, metadata): """Test the ``_remove_missing_columns_metadata`` method.""" # Setup expected_warning_missing_column_metadata = re.escape( - "Some columns ('extra_column_1', 'extra_column_2') are not present in the metadata." + "The columns ('extra_column_1', 'extra_column_2') are not present in the metadata." 'They will not be included for further evaluation.' ) expected_warning_extra_metadata_column = re.escape( - "Some columns ('numerical') are in the metadata but they are not present in the data." + "The columns ('numerical') are in the metadata but they are not present in the data." ) data['table1'] = data['table1'].drop(columns=['numerical']) @@ -299,43 +298,3 @@ def test__process_data_with_metadata( _process_data_with_metadata(data, metadata, keep_modelable_columns_only=True) mock_remove_non_modelable_columns.assert_called_once_with(data, metadata) - - -@patch('sdmetrics._utils_metadata._process_data_with_metadata') -def test__process_data_with_metadata_ml_efficacy_metrics(mock_process_data_with_metadata): - """Test the ``_process_data_with_metadata_ml_efficacy_metrics`` method.""" - # Setup - mock_process_data_with_metadata.side_effect = lambda data, metadata, x: data - real_training_data = pd.DataFrame({ - 'numerical': [1, 2, 3], - 'categorical': ['a', 'b', 'c'], - }) - synthetic_data = pd.DataFrame({ - 'numerical': [4, 5, 6], - 'categorical': ['a', 'b', 'c'], - }) - real_validation_data = pd.DataFrame({ - 'numerical': [7, 8, 9], - 'categorical': ['a', 'b', 'c'], - }) - metadata = { - 'columns': { - 'numerical': {'sdtype': 'numerical'}, - 'categorical': {'sdtype': 'categorical'}, - } - } - - # Run - result = _process_data_with_metadata_ml_efficacy_metrics( - real_training_data, synthetic_data, real_validation_data, metadata - ) - - # Assert - pd.testing.assert_frame_equal(result[0], real_training_data) - pd.testing.assert_frame_equal(result[1], synthetic_data) - pd.testing.assert_frame_equal(result[2], real_validation_data) - mock_process_data_with_metadata.assert_has_calls([ - call(real_training_data, metadata, True), - call(synthetic_data, metadata, True), - call(real_validation_data, metadata, True), - ]) From 25d370488ce20cdc20ce4151ad17feb37af48916 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 17:28:09 +0000 Subject: [PATCH 11/14] fix unit test --- tests/unit/single_table/test_bayesian_network.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 8a321099..fc15bc58 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -24,10 +24,7 @@ def test_compute(self, bad_pomegranate): metric = BNLikelihood() # Act and Assert - expected_message = ( - r'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' - r' Python 3.13 is not supported.' - ) + expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) @@ -39,9 +36,8 @@ def test_compute(self, bad_pomegranate): metric = BNLogLikelihood() # Act and Assert - expected_message = ( - r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. ' - r'Python 3\.13 is not supported\.' + expected_message = expected_message = ( + r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) From b31129df8cb9a53e5b21355559c78811ac58016f Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 19:04:32 +0000 Subject: [PATCH 12/14] try/catch datetime conversion --- sdmetrics/_utils_metadata.py | 22 ++++++++++++++----- tests/unit/test__utils_metadata.py | 35 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index a1d79297..ffe3e706 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -73,16 +73,28 @@ def wrapper(data, metadata): def _convert_datetime_columns(data, metadata): """Convert datetime columns to datetime type.""" columns_missing_datetime_format = [] + message_conversion_failed = ( + 'Conversion to datetime failed for the following columns with errors: \n {}' + ) + columns_with_conversion_issues = [] for column in metadata['columns']: if metadata['columns'][column]['sdtype'] == 'datetime': is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) if not is_datetime: datetime_format = metadata['columns'][column].get('format') - if datetime_format: - data[column] = pd.to_datetime(data[column], format=datetime_format) - else: - columns_missing_datetime_format.append(column) - data[column] = pd.to_datetime(data[column]) + try: + if datetime_format: + data[column] = pd.to_datetime(data[column], format=datetime_format) + else: + columns_missing_datetime_format.append(column) + data[column] = pd.to_datetime(data[column]) + except Exception as e: + columns_with_conversion_issues.append(f"'{column}': {str(e)}") + + if columns_with_conversion_issues: + raise ValueError( + message_conversion_failed.format('\n'.join(columns_with_conversion_issues)) + ) if columns_missing_datetime_format: columns_to_print = "', '".join(columns_missing_datetime_format) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index a5b5578d..9d147c3b 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -182,6 +182,41 @@ def test__convert_datetime_columns(data, metadata): pd.testing.assert_frame_equal(result_single_table, expected_df_single_table) +def test_convert_datetime_columns_with_failures(): + """Test the ``_convert_datetime_columns`` when pandas can't convert to datetime.""" + # Setup + wrong_data = pd.DataFrame({ + 'numerical': [1, 2, 3], + 'categorical': ['a', 'b', 'c'], + 'datetime_1': ['2021-01-01', '20-error', '2021-01-03'], + 'datetime_2': ['2025-01-01', '2025-01-24', '2025-13-04'], + }) + metadata = { + 'columns': { + 'numerical': {'sdtype': 'numerical'}, + 'categorical': {'sdtype': 'categorical'}, + 'datetime_1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'datetime_2': {'sdtype': 'datetime'}, + } + } + error_message = ( + r'\s*Conversion to datetime failed for the following columns with errors:\s*' + r"\s*'datetime_1': time data \"20-error\" doesn't match format \"%Y-%m-%d\", at " + r'position 1\.\s*You might want to try:\s*- passing `format` if your strings have a ' + r"consistent format;\s*- passing `format='ISO8601'` if your strings are all ISO8601 " + r"but not necessarily in exactly the same format;\s*- passing `format='mixed'`, and " + r'the format will be inferred for each element individually\. You might want to use ' + r"`dayfirst` alongside this\.\s*'datetime_2': time data \"2025-13-04\" doesn't match " + r'format \"%Y-%m-%d\", at position 2\.\s*You might want to try:\s*- passing `format` ' + r"if your strings have a consistent format;\s*- passing `format='ISO8601'` if your " + r'strings are all ISO8601 but not necessarily in exactly the same format;\s*- passing ' + r"`format='mixed'`, and the format will be inferred for each element individually\. " + r'You might want to use `dayfirst` alongside this\.' + ) + with pytest.raises(ValueError, match=error_message): + _convert_datetime_columns(wrong_data, metadata) + + def test__remove_missing_columns_metadata(data, metadata): """Test the ``_remove_missing_columns_metadata`` method.""" # Setup From 8bad53b9ea9aaf333dbf527e7aad978c7f22957b Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 19:06:02 +0000 Subject: [PATCH 13/14] # Run and Assert --- tests/unit/test__utils_metadata.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index 9d147c3b..49a2a7bc 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -213,6 +213,8 @@ def test_convert_datetime_columns_with_failures(): r"`format='mixed'`, and the format will be inferred for each element individually\. " r'You might want to use `dayfirst` alongside this\.' ) + + # Run and Assert with pytest.raises(ValueError, match=error_message): _convert_datetime_columns(wrong_data, metadata) From 4a20f7e8f5e56af0d0b910ce0dec7af89a6175f8 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 24 Feb 2025 19:50:19 +0000 Subject: [PATCH 14/14] exception chaining --- sdmetrics/_utils_metadata.py | 13 +++---------- tests/unit/test__utils_metadata.py | 15 +-------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index ffe3e706..f857a1b3 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -73,10 +73,6 @@ def wrapper(data, metadata): def _convert_datetime_columns(data, metadata): """Convert datetime columns to datetime type.""" columns_missing_datetime_format = [] - message_conversion_failed = ( - 'Conversion to datetime failed for the following columns with errors: \n {}' - ) - columns_with_conversion_issues = [] for column in metadata['columns']: if metadata['columns'][column]['sdtype'] == 'datetime': is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) @@ -89,12 +85,9 @@ def _convert_datetime_columns(data, metadata): columns_missing_datetime_format.append(column) data[column] = pd.to_datetime(data[column]) except Exception as e: - columns_with_conversion_issues.append(f"'{column}': {str(e)}") - - if columns_with_conversion_issues: - raise ValueError( - message_conversion_failed.format('\n'.join(columns_with_conversion_issues)) - ) + raise ValueError( + f"Failed to convert column '{column}' to datetime with the error: {str(e)}" + ) from e if columns_missing_datetime_format: columns_to_print = "', '".join(columns_missing_datetime_format) diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index 49a2a7bc..f94e6218 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -199,20 +199,7 @@ def test_convert_datetime_columns_with_failures(): 'datetime_2': {'sdtype': 'datetime'}, } } - error_message = ( - r'\s*Conversion to datetime failed for the following columns with errors:\s*' - r"\s*'datetime_1': time data \"20-error\" doesn't match format \"%Y-%m-%d\", at " - r'position 1\.\s*You might want to try:\s*- passing `format` if your strings have a ' - r"consistent format;\s*- passing `format='ISO8601'` if your strings are all ISO8601 " - r"but not necessarily in exactly the same format;\s*- passing `format='mixed'`, and " - r'the format will be inferred for each element individually\. You might want to use ' - r"`dayfirst` alongside this\.\s*'datetime_2': time data \"2025-13-04\" doesn't match " - r'format \"%Y-%m-%d\", at position 2\.\s*You might want to try:\s*- passing `format` ' - r"if your strings have a consistent format;\s*- passing `format='ISO8601'` if your " - r'strings are all ISO8601 but not necessarily in exactly the same format;\s*- passing ' - r"`format='mixed'`, and the format will be inferred for each element individually\. " - r'You might want to use `dayfirst` alongside this\.' - ) + error_message = r"^Failed to convert column 'datetime_1' to datetime with the error:" # Run and Assert with pytest.raises(ValueError, match=error_message):