From 98b7bf6c4f800e80d997862d62eaba3bd8e71465 Mon Sep 17 00:00:00 2001 From: John La Date: Fri, 21 Mar 2025 13:59:38 -0500 Subject: [PATCH 1/5] Improve handling of datetime columns for DCR metrics (#751) --- sdmetrics/_utils_metadata.py | 30 +++++----------- .../test_binary_classifier_recall_efficacy.py | 8 ++--- tests/unit/test__utils_metadata.py | 35 +++++++++++++++---- 3 files changed, 42 insertions(+), 31 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index f857a1b3..9fe13095 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -72,31 +72,19 @@ def wrapper(data, metadata): @handle_single_and_multi_table def _convert_datetime_columns(data, metadata): """Convert datetime columns to datetime type.""" - columns_missing_datetime_format = [] for column in metadata['columns']: if metadata['columns'][column]['sdtype'] == 'datetime': is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) if not is_datetime: - datetime_format = metadata['columns'][column].get('format') - try: - if datetime_format: - data[column] = pd.to_datetime(data[column], format=datetime_format) - else: - columns_missing_datetime_format.append(column) - data[column] = pd.to_datetime(data[column]) - except Exception as e: + datetime_format = metadata['columns'][column].get('datetime_format') + if datetime_format: + data[column] = pd.to_datetime(data[column], format=datetime_format) + else: raise ValueError( - f"Failed to convert column '{column}' to datetime with the error: {str(e)}" - ) from e - - if columns_missing_datetime_format: - columns_to_print = "', '".join(columns_missing_datetime_format) - warnings.warn( - f'No `datetime_format` provided in the metadata when trying to convert the columns' - f" '{columns_to_print}' to datetime. The format will be inferred, but it may not" - ' be accurate.', - UserWarning, - ) + f"Datetime column '{column}' does not have a specified 'datetime_format'. " + 'Please add a the required datetime_format to the metadata or convert this column ' + "to 'pd.datetime' to bypass this requirement." + ) return data @@ -111,7 +99,7 @@ def _remove_missing_columns_metadata(data, metadata): if columns_to_remove: columns_to_print = "', '".join(sorted(columns_to_remove)) warnings.warn( - f"The columns ('{columns_to_print}') are not present in the metadata." + f"The columns ('{columns_to_print}') are not present in the metadata. " 'They will not be included for further evaluation.', UserWarning, ) diff --git a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py index 60bf2a54..21ce9b1d 100644 --- a/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py +++ b/tests/integration/single_table/data_augmentation/test_binary_classifier_recall_efficacy.py @@ -194,13 +194,13 @@ def test_speical_data_metadata_config(self): mask_validation = np.random.rand(len(real_data)) < 0.8 real_training = real_data[mask_validation] real_validation = real_data[~mask_validation] - warning_datetime = re.escape( - 'No `datetime_format` provided in the metadata when trying to convert the columns' - " 'start_date' to datetime. The format will be inferred, but it may not be accurate." + warning_extra_col = re.escape( + "The columns ('extra_column') are not present in the metadata. " + 'They will not be included for further evaluation.' ) # Run - with pytest.warns(UserWarning, match=warning_datetime): + with pytest.warns(UserWarning, match=warning_extra_col): score = BinaryClassifierRecallEfficacy.compute( real_training_data=real_training, synthetic_data=synthetic_data, diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index f94e6218..88b82491 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -172,6 +172,16 @@ def test__convert_datetime_columns(data, metadata): } # Run + error_msg = ( + "Datetime column 'datetime_missing_format' does not have a specified 'datetime_format'. " + 'Please add a the required datetime_format to the metadata or convert this column ' + "to 'pd.datetime' to bypass this requirement." + ) + with pytest.raises(ValueError, match=error_msg): + _convert_datetime_columns(data, metadata) + + table2_columns = metadata['tables']['table2']['columns'] + table2_columns['datetime_missing_format']['datetime_format'] = '%Y-%m-%d' result_multi_table = _convert_datetime_columns(data, metadata) result_single_table = _convert_datetime_columns(data['table1'], metadata['tables']['table1']) @@ -185,12 +195,13 @@ def test__convert_datetime_columns(data, metadata): def test_convert_datetime_columns_with_failures(): """Test the ``_convert_datetime_columns`` when pandas can't convert to datetime.""" # Setup - wrong_data = pd.DataFrame({ + data = pd.DataFrame({ 'numerical': [1, 2, 3], 'categorical': ['a', 'b', 'c'], 'datetime_1': ['2021-01-01', '20-error', '2021-01-03'], - 'datetime_2': ['2025-01-01', '2025-01-24', '2025-13-04'], + 'datetime_2': ['2025-01-01', '2025-01-24', '2025-01-04'], }) + metadata = { 'columns': { 'numerical': {'sdtype': 'numerical'}, @@ -199,18 +210,30 @@ def test_convert_datetime_columns_with_failures(): 'datetime_2': {'sdtype': 'datetime'}, } } - error_message = r"^Failed to convert column 'datetime_1' to datetime with the error:" # Run and Assert - with pytest.raises(ValueError, match=error_message): - _convert_datetime_columns(wrong_data, metadata) + error_msg_bad_format = 'match format' + with pytest.raises(ValueError, match=error_msg_bad_format): + _convert_datetime_columns(data, metadata) + + data['datetime_1'] = ['2021-01-01', '2021-01-02', '2021-01-03'] + + error_msg_missing_format = "does not have a specified 'datetime_format'" + with pytest.raises(ValueError, match=error_msg_missing_format): + _convert_datetime_columns(data, metadata) + + metadata['columns']['datetime_2']['datetime_format'] = '%Y-%m-%d' + + result = _convert_datetime_columns(data, metadata) + assert pd.api.types.is_datetime64_any_dtype(result['datetime_1'].dtype) + assert pd.api.types.is_datetime64_any_dtype(result['datetime_2'].dtype) def test__remove_missing_columns_metadata(data, metadata): """Test the ``_remove_missing_columns_metadata`` method.""" # Setup expected_warning_missing_column_metadata = re.escape( - "The columns ('extra_column_1', 'extra_column_2') are not present in the metadata." + "The columns ('extra_column_1', 'extra_column_2') are not present in the metadata. " 'They will not be included for further evaluation.' ) expected_warning_extra_metadata_column = re.escape( From 2ad59329784b0385e0caa268998577726259dec5 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Thu, 27 Mar 2025 10:17:32 -0400 Subject: [PATCH 2/5] Consolidate handling of datetime columns --- sdmetrics/_utils_metadata.py | 29 +++++---- sdmetrics/reports/base_report.py | 7 +-- .../_properties/column_pair_trends.py | 10 +-- sdmetrics/reports/utils.py | 47 -------------- .../test_base_multi_table_report.py | 5 +- tests/unit/reports/test_base_report.py | 5 +- tests/unit/reports/test_utils.py | 61 ------------------- 7 files changed, 31 insertions(+), 133 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 9fe13095..be16edb5 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -2,6 +2,8 @@ import pandas as pd +from sdmetrics.utils import is_datetime + MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean') @@ -69,22 +71,27 @@ def wrapper(data, metadata): return wrapper +def _convert_datetime_column(column_name, column_data, column_metadata): + if is_datetime(column_data): + return column_data + + datetime_format = column_metadata.get('datetime_format') + if datetime_format is None: + raise ValueError( + f"Datetime column '{column_name}' does not have a specified 'datetime_format'. " + 'Please add a the required datetime_format to the metadata or convert this column ' + "to 'pd.datetime' to bypass this requirement." + ) + + return pd.to_datetime(column_data, format=datetime_format) + + @handle_single_and_multi_table def _convert_datetime_columns(data, metadata): """Convert datetime columns to datetime type.""" for column in metadata['columns']: if metadata['columns'][column]['sdtype'] == 'datetime': - is_datetime = pd.api.types.is_datetime64_any_dtype(data[column]) - if not is_datetime: - datetime_format = metadata['columns'][column].get('datetime_format') - if datetime_format: - data[column] = pd.to_datetime(data[column], format=datetime_format) - else: - raise ValueError( - f"Datetime column '{column}' does not have a specified 'datetime_format'. " - 'Please add a the required datetime_format to the metadata or convert this column ' - "to 'pd.datetime' to bypass this requirement." - ) + data[column] = _convert_datetime_column(column, data[column], metadata['columns'][column]) return data diff --git a/sdmetrics/reports/base_report.py b/sdmetrics/reports/base_report.py index 8fe6002d..befcb204 100644 --- a/sdmetrics/reports/base_report.py +++ b/sdmetrics/reports/base_report.py @@ -13,8 +13,7 @@ import pandas as pd import tqdm -from sdmetrics._utils_metadata import _validate_metadata -from sdmetrics.reports.utils import convert_datetime_columns +from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata from sdmetrics.visualization import set_plotly_config @@ -101,8 +100,8 @@ def convert_datetimes(real_data, synthetic_data, metadata): real_col = real_data[column] synth_col = synthetic_data[column] try: - converted_cols = convert_datetime_columns(real_col, synth_col, col_meta) - real_data[column], synthetic_data[column] = converted_cols + real_data[column] = _convert_datetime_column(column, real_col, col_meta) + synthetic_data[column] = _convert_datetime_column(column, synth_col, col_meta) except Exception: continue diff --git a/sdmetrics/reports/single_table/_properties/column_pair_trends.py b/sdmetrics/reports/single_table/_properties/column_pair_trends.py index 693e384a..885effd3 100644 --- a/sdmetrics/reports/single_table/_properties/column_pair_trends.py +++ b/sdmetrics/reports/single_table/_properties/column_pair_trends.py @@ -5,10 +5,10 @@ from plotly import graph_objects as go from plotly.subplots import make_subplots +from sdmetrics._utils_metadata import _convert_datetime_column from sdmetrics.column_pairs.statistical import ContingencySimilarity, CorrelationSimilarity from sdmetrics.reports.single_table._properties import BaseSingleTableProperty from sdmetrics.reports.utils import PlotConfig -from sdmetrics.utils import is_datetime DEFAULT_NUM_ROWS_SUBSAMPLE = 50000 @@ -51,13 +51,7 @@ def _convert_datetime_columns_to_numeric(self, data, metadata): col_sdtype = column_meta['sdtype'] try: if col_sdtype == 'datetime': - if not is_datetime(data[column_name]): - datetime_format = column_meta.get( - 'datetime_format', column_meta.get('format') - ) - data[column_name] = pd.to_datetime( - data[column_name], format=datetime_format - ) + data[column_name] = _convert_datetime_column(column_name, data[column_name], column_meta) nan_mask = pd.isna(data[column_name]) data[column_name] = pd.to_numeric(data[column_name]) if nan_mask.any(): diff --git a/sdmetrics/reports/utils.py b/sdmetrics/reports/utils.py index b031651a..b5c1ec53 100644 --- a/sdmetrics/reports/utils.py +++ b/sdmetrics/reports/utils.py @@ -6,14 +6,12 @@ import numpy as np import pandas as pd -from pandas.core.tools.datetimes import _guess_datetime_format_for_array from sdmetrics.utils import ( discretize_column, get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta, - is_datetime, ) CONTINUOUS_SDTYPES = ['numerical', 'datetime'] @@ -35,51 +33,6 @@ class PlotConfig: FONT_SIZE = 18 -def convert_to_datetime(column_data, datetime_format=None): - """Convert a column data to pandas datetime. - - Args: - column_data (pandas.Series): - The column data - format (str): - Optional string format of datetime. If ``None``, will attempt to infer the datetime - format from the column data. Defaults to ``None``. - - Returns: - pandas.Series: - The converted column data. - """ - if is_datetime(column_data): - return column_data - - if datetime_format is None: - datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy()) - - return pd.to_datetime(column_data, format=datetime_format) - - -def convert_datetime_columns(real_column, synthetic_column, col_metadata): - """Convert a real and a synthetic column to pandas datetime. - - Args: - real_data (pandas.Series): - The real column data - synthetic_column (pandas.Series): - The synthetic column data - col_metadata: - The metadata associated with the column - - Returns: - (pandas.Series, pandas.Series): - The converted real and synthetic column data. - """ - datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format') - return ( - convert_to_datetime(real_column, datetime_format), - convert_to_datetime(synthetic_column, datetime_format), - ) - - def discretize_table_data(real_data, synthetic_data, metadata): """Create a copy of the real and synthetic data with discretized data. diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index b2f308e0..1fc6b187 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -209,7 +209,10 @@ def test_convert_datetimes(self): metadata = { 'tables': { 'table1': { - 'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}}, + 'columns': { + 'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'col2': {'sdtype': 'datetime'} + }, }, }, } diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index 4e72f54e..aa932e35 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -174,7 +174,10 @@ def test_convert_datetimes(self): real_data = pd.DataFrame({'col1': ['2020-01-02', '2021-01-02'], 'col2': ['a', 'b']}) synthetic_data = pd.DataFrame({'col1': ['2022-01-03', '2023-04-05'], 'col2': ['b', 'a']}) metadata = { - 'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}}, + 'columns': { + 'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + 'col2': {'sdtype': 'datetime'} + }, } # Run diff --git a/tests/unit/reports/test_utils.py b/tests/unit/reports/test_utils.py index 45b7ad25..c24046f4 100644 --- a/tests/unit/reports/test_utils.py +++ b/tests/unit/reports/test_utils.py @@ -5,73 +5,12 @@ from sdmetrics.reports.utils import ( aggregate_metric_results, - convert_to_datetime, discretize_and_apply_metric, discretize_table_data, ) from tests.utils import DataFrameMatcher -def test_convert_to_datetime(): - """Test the ``convert_to_datetime`` method with a datetime column. - - Expect no conversion to happen since the input is already a pandas datetime type. - - Inputs: - - datetime column - - Output: - - datetime column - """ - # Setup - column_data = pd.Series([datetime(2020, 1, 2), datetime(2021, 1, 2)]) - - # Run - out = convert_to_datetime(column_data) - - # Assert - pd.testing.assert_series_equal(out, column_data) - - -def test_convert_to_datetime_date_column(): - """Test the ``convert_to_datetime`` method with a date column. - - Expect the date column to be converted to a datetime column. - - Inputs: - - date column - - Output: - - datetime column - """ - # Setup - column_data = pd.Series([date(2020, 1, 2), date(2021, 1, 2)]) - - # Run - out = convert_to_datetime(column_data) - - # Assert - expected = pd.Series([datetime(2020, 1, 2), datetime(2021, 1, 2)]) - pd.testing.assert_series_equal(out, expected) - - -def test_convert_to_datetime_str_format(): - """Test the ``convert_to_datetime`` method with a string column. - - Expect the string date column to be converted to a datetime column - using the provided format. - """ - # Setup - column_data = pd.Series(['2020-01-02', '2021-01-02']) - - # Run - out = convert_to_datetime(column_data) - - # Assert - expected = pd.Series([datetime(2020, 1, 2), datetime(2021, 1, 2)]) - pd.testing.assert_series_equal(out, expected) - - def test_discretize_table_data(): """Test the ``discretize_table_data`` method. From 08356ba8d021b68ab2307e33816c33521678c7ae Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Thu, 27 Mar 2025 10:48:03 -0400 Subject: [PATCH 3/5] Lint --- sdmetrics/_utils_metadata.py | 4 +++- .../reports/single_table/_properties/column_pair_trends.py | 4 +++- .../unit/reports/multi_table/test_base_multi_table_report.py | 2 +- tests/unit/reports/test_base_report.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index be16edb5..98605c49 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -91,7 +91,9 @@ def _convert_datetime_columns(data, metadata): """Convert datetime columns to datetime type.""" for column in metadata['columns']: if metadata['columns'][column]['sdtype'] == 'datetime': - data[column] = _convert_datetime_column(column, data[column], metadata['columns'][column]) + data[column] = _convert_datetime_column( + column, data[column], metadata['columns'][column] + ) return data diff --git a/sdmetrics/reports/single_table/_properties/column_pair_trends.py b/sdmetrics/reports/single_table/_properties/column_pair_trends.py index 885effd3..9cc894b1 100644 --- a/sdmetrics/reports/single_table/_properties/column_pair_trends.py +++ b/sdmetrics/reports/single_table/_properties/column_pair_trends.py @@ -51,7 +51,9 @@ def _convert_datetime_columns_to_numeric(self, data, metadata): col_sdtype = column_meta['sdtype'] try: if col_sdtype == 'datetime': - data[column_name] = _convert_datetime_column(column_name, data[column_name], column_meta) + data[column_name] = _convert_datetime_column( + column_name, data[column_name], column_meta + ) nan_mask = pd.isna(data[column_name]) data[column_name] = pd.to_numeric(data[column_name]) if nan_mask.any(): diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index 1fc6b187..b13883e8 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -211,7 +211,7 @@ def test_convert_datetimes(self): 'table1': { 'columns': { 'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'col2': {'sdtype': 'datetime'} + 'col2': {'sdtype': 'datetime'}, }, }, }, diff --git a/tests/unit/reports/test_base_report.py b/tests/unit/reports/test_base_report.py index aa932e35..99e9199c 100644 --- a/tests/unit/reports/test_base_report.py +++ b/tests/unit/reports/test_base_report.py @@ -176,7 +176,7 @@ def test_convert_datetimes(self): metadata = { 'columns': { 'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, - 'col2': {'sdtype': 'datetime'} + 'col2': {'sdtype': 'datetime'}, }, } From 188655ec28905fd0edaa522374b46e64367dc709 Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Fri, 28 Mar 2025 13:04:20 -0400 Subject: [PATCH 4/5] Remove lingering pd.to_datetime and add error handling if converting with datetime_format fails --- sdmetrics/_utils_metadata.py | 5 +++ sdmetrics/reports/utils.py | 7 ++-- sdmetrics/single_table/base.py | 18 ++++---- sdmetrics/single_table/new_row_synthesis.py | 8 ++-- sdmetrics/timeseries/base.py | 24 ++++------- .../integration/timeseries/test_timeseries.py | 5 ++- tests/unit/reports/test_utils.py | 2 +- tests/unit/test__utils_metadata.py | 41 +++++++++++++++++++ tests/unit/timeseries/test_timeseries.py | 5 ++- 9 files changed, 79 insertions(+), 36 deletions(-) diff --git a/sdmetrics/_utils_metadata.py b/sdmetrics/_utils_metadata.py index 98605c49..b6570003 100644 --- a/sdmetrics/_utils_metadata.py +++ b/sdmetrics/_utils_metadata.py @@ -83,6 +83,11 @@ def _convert_datetime_column(column_name, column_data, column_metadata): "to 'pd.datetime' to bypass this requirement." ) + try: + pd.to_datetime(column_data, format=datetime_format) + except Exception as e: + raise ValueError(f"Error converting column '{column_name}' to timestamp: {e}") + return pd.to_datetime(column_data, format=datetime_format) diff --git a/sdmetrics/reports/utils.py b/sdmetrics/reports/utils.py index b5c1ec53..178e9cd0 100644 --- a/sdmetrics/reports/utils.py +++ b/sdmetrics/reports/utils.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd +from sdmetrics._utils_metadata import _convert_datetime_column from sdmetrics.utils import ( discretize_column, get_alternate_keys, @@ -62,10 +63,8 @@ def discretize_table_data(real_data, synthetic_data, metadata): real_col = real_data[column_name] synthetic_col = synthetic_data[column_name] if sdtype == 'datetime': - datetime_format = column_meta.get('format') or column_meta.get('datetime_format') - if real_col.dtype == 'O' and datetime_format: - real_col = pd.to_datetime(real_col, format=datetime_format) - synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format) + real_col = _convert_datetime_column(column_name, real_col, column_meta) + synthetic_col = _convert_datetime_column(column_name, synthetic_col, column_meta) real_col = pd.to_numeric(real_col) synthetic_col = pd.to_numeric(synthetic_col) diff --git a/sdmetrics/single_table/base.py b/sdmetrics/single_table/base.py index 25257adc..598528ce 100644 --- a/sdmetrics/single_table/base.py +++ b/sdmetrics/single_table/base.py @@ -3,9 +3,7 @@ import copy from operator import attrgetter -import pandas as pd - -from sdmetrics._utils_metadata import _validate_single_table_metadata +from sdmetrics._utils_metadata import _convert_datetime_column, _validate_single_table_metadata from sdmetrics.base import BaseMetric from sdmetrics.errors import IncomputableMetricError from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta @@ -119,14 +117,12 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None): field_type = get_type_from_column_meta(field_meta) if field not in real_data.columns: raise ValueError(f'Field {field} not found in data') - if ( - field_type == 'datetime' - and 'datetime_format' in field_meta - and real_data[field].dtype == 'O' - ): - dt_format = field_meta['datetime_format'] - real_data[field] = pd.to_datetime(real_data[field], format=dt_format) - synthetic_data[field] = pd.to_datetime(synthetic_data[field], format=dt_format) + + if field_type == 'datetime': + real_data[field] = _convert_datetime_column(field, real_data[field], field_meta) + synthetic_data[field] = _convert_datetime_column( + field, synthetic_data[field], field_meta + ) return real_data, synthetic_data, metadata diff --git a/sdmetrics/single_table/new_row_synthesis.py b/sdmetrics/single_table/new_row_synthesis.py index 8f9aadd6..80786cff 100644 --- a/sdmetrics/single_table/new_row_synthesis.py +++ b/sdmetrics/single_table/new_row_synthesis.py @@ -4,6 +4,7 @@ import pandas as pd +from sdmetrics._utils_metadata import _convert_datetime_column from sdmetrics.errors import IncomputableMetricError from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric @@ -83,9 +84,10 @@ def compute_breakdown( for field, field_meta in get_columns_from_metadata(metadata).items(): if get_type_from_column_meta(field_meta) == 'datetime': - if len(real_data[field]) > 0 and isinstance(real_data[field][0], str): - real_data[field] = pd.to_datetime(real_data[field]) - synthetic_data[field] = pd.to_datetime(synthetic_data[field]) + real_data[field] = _convert_datetime_column(field, real_data[field], field_meta) + synthetic_data[field] = _convert_datetime_column( + field, synthetic_data[field], field_meta + ) real_data[field] = pd.to_numeric(real_data[field]) synthetic_data[field] = pd.to_numeric(synthetic_data[field]) diff --git a/sdmetrics/timeseries/base.py b/sdmetrics/timeseries/base.py index 6c157458..f913043f 100644 --- a/sdmetrics/timeseries/base.py +++ b/sdmetrics/timeseries/base.py @@ -2,9 +2,7 @@ from operator import attrgetter -import pandas as pd - -from sdmetrics._utils_metadata import _validate_metadata_dict +from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata_dict from sdmetrics.base import BaseMetric from sdmetrics.utils import get_columns_from_metadata @@ -62,18 +60,14 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key if field not in real_data.columns: raise ValueError(f'Field {field} not found in data') - for column, kwargs in metadata['columns'].items(): - if kwargs['sdtype'] == 'datetime': - datetime_format = kwargs.get('datetime_format') - try: - real_data[column] = pd.to_datetime( - real_data[column], format=datetime_format - ) - synthetic_data[column] = pd.to_datetime( - synthetic_data[column], format=datetime_format - ) - except ValueError: - raise ValueError(f"Column '{column}' is not a valid datetime") + for column, col_metadata in metadata['columns'].items(): + if col_metadata['sdtype'] == 'datetime': + real_data[column] = _convert_datetime_column( + column, real_data[column], col_metadata + ) + synthetic_data[column] = _convert_datetime_column( + column, synthetic_data[column], col_metadata + ) else: dtype_kinds = real_data.dtypes.apply(attrgetter('kind')) diff --git a/tests/integration/timeseries/test_timeseries.py b/tests/integration/timeseries/test_timeseries.py index b2d97acb..8d4674ba 100644 --- a/tests/integration/timeseries/test_timeseries.py +++ b/tests/integration/timeseries/test_timeseries.py @@ -76,7 +76,10 @@ def test_compute_lstmdetection_mismatching_datetime_columns(): 'visits': ['1/2/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019'], }) metadata = { - 'columns': {'s_key': {'sdtype': 'numerical'}, 'visits': {'sdtype': 'datetime'}}, + 'columns': { + 's_key': {'sdtype': 'numerical'}, + 'visits': {'sdtype': 'datetime', 'datetime_format': '%m/%d/%Y'}, + }, 'sequence_key': 's_key', } diff --git a/tests/unit/reports/test_utils.py b/tests/unit/reports/test_utils.py index c24046f4..e64e4ac6 100644 --- a/tests/unit/reports/test_utils.py +++ b/tests/unit/reports/test_utils.py @@ -47,7 +47,7 @@ def test_discretize_table_data(): 'col2': {'sdtype': 'categorical'}, 'col3': {'sdtype': 'datetime'}, 'col4': {'sdtype': 'boolean'}, - 'col5': {'sdtype': 'datetime', 'format': '%Y-%m-%d'}, + 'col5': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, }, } diff --git a/tests/unit/test__utils_metadata.py b/tests/unit/test__utils_metadata.py index 88b82491..421d9240 100644 --- a/tests/unit/test__utils_metadata.py +++ b/tests/unit/test__utils_metadata.py @@ -6,6 +6,7 @@ import pytest from sdmetrics._utils_metadata import ( + _convert_datetime_column, _convert_datetime_columns, _process_data_with_metadata, _remove_missing_columns_metadata, @@ -157,6 +158,46 @@ def test__validate_metadata(mock_validate_multi_table_metadata, metadata): mock_validate_multi_table_metadata.assert_called_once_with(metadata) +def test__convert_datetime_column(data, metadata): + """Test the ``_convert_datetime_column`` method.""" + # Setup + column_metadata = {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + str_col = pd.Series(['2021-01-01', '2021-01-02', '2021-01-03']) + datetime = pd.Series([ + pd.Timestamp('2021-01-01'), + pd.Timestamp('2021-01-02'), + pd.Timestamp('2021-01-03'), + ]) + + # Run + expected_msg = re.escape( + "Datetime column 'datetime_no_format' does not have a specified 'datetime_format'. " + 'Please add a the required datetime_format to the metadata or convert this column ' + "to 'pd.datetime' to bypass this requirement." + ) + with pytest.raises(ValueError, match=expected_msg): + _convert_datetime_column('datetime_no_format', str_col, {'sdtype': 'datetime'}) + + datetime_result = _convert_datetime_column('datetime', datetime, column_metadata) + str_result = _convert_datetime_column('datetime_str', str_col, column_metadata) + + # Assert + pd.testing.assert_series_equal(datetime, datetime_result) + pd.testing.assert_series_equal(datetime, str_result) + + +def test__convert_datetime_column_bad_format(data, metadata): + """Test the ``_convert_datetime_columns`` method when the provided format fails.""" + # Setup + column_metadata = {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'} + bad_col = pd.Series(['bad', 'datetime', 'values']) + + # Run and assert + expected_msg = re.escape("Error converting column 'datetime' to timestamp: ") + with pytest.raises(ValueError, match=expected_msg): + _convert_datetime_column('datetime', bad_col, column_metadata) + + def test__convert_datetime_columns(data, metadata): """Test the ``_convert_datetime_columns`` method.""" # Setup diff --git a/tests/unit/timeseries/test_timeseries.py b/tests/unit/timeseries/test_timeseries.py index 3dc22679..9240ead8 100644 --- a/tests/unit/timeseries/test_timeseries.py +++ b/tests/unit/timeseries/test_timeseries.py @@ -1,3 +1,5 @@ +import re + import pandas as pd import pytest @@ -25,7 +27,8 @@ def test__validate_inputs_for_TimeSeriesMetric(): } # Run and Assert - with pytest.raises(ValueError, match="Column 'visits' is not a valid datetime"): + expected_msg = re.escape("Error converting column 'visits' to timestamp: ") + with pytest.raises(ValueError, match=expected_msg): TimeSeriesMetric._validate_inputs( real_data=df1, synthetic_data=df2, sequence_key=['s_key'], metadata=metadata ) From c8c1bdcdd81fb7a601f8746437879cce5a5077bf Mon Sep 17 00:00:00 2001 From: Frances Hartwell Date: Fri, 28 Mar 2025 14:13:50 -0400 Subject: [PATCH 5/5] Fix minimum test --- tests/integration/timeseries/test_timeseries.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/timeseries/test_timeseries.py b/tests/integration/timeseries/test_timeseries.py index 8d4674ba..fb7847c8 100644 --- a/tests/integration/timeseries/test_timeseries.py +++ b/tests/integration/timeseries/test_timeseries.py @@ -63,14 +63,13 @@ def test_compute_lstmdetection_multiple_categorical_columns(): def test_compute_lstmdetection_mismatching_datetime_columns(): """Test LSTMDetection metric with mismatching datetime columns. - Test it when the real data has a date column and the synthetic data has a string column. + Test it when the real data has a datetime column and the synthetic data has a string column. """ # Setup df1 = pd.DataFrame({ 's_key': [1, 2, 3, 4, 5], 'visits': pd.to_datetime(['1/1/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019']), }) - df1['visits'] = df1['visits'].dt.date df2 = pd.DataFrame({ 's_key': [1, 2, 3, 4, 5], 'visits': ['1/2/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019'],