Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions sdmetrics/_utils_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pandas as pd

from sdmetrics.utils import is_datetime

MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean')


Expand Down Expand Up @@ -69,34 +71,34 @@ def wrapper(data, metadata):
return wrapper


def _convert_datetime_column(column_name, column_data, column_metadata):
    """Convert a single column to pandas datetime using its metadata format.

    Args:
        column_name (str):
            Name of the column. Only used to build the error messages.
        column_data (pandas.Series):
            The column data to convert.
        column_metadata (dict):
            The metadata associated with the column. Its ``datetime_format`` entry
            is used as the parsing format.

    Returns:
        pandas.Series:
            ``column_data`` unchanged when ``is_datetime`` reports it is already
            datetime, otherwise the result of ``pd.to_datetime`` called with the
            ``datetime_format`` from the metadata.

    Raises:
        ValueError:
            If the column is not already datetime and the metadata has no
            ``datetime_format``, or if the conversion fails for any reason.
    """
    if is_datetime(column_data):
        return column_data

    datetime_format = column_metadata.get('datetime_format')
    if datetime_format is None:
        raise ValueError(
            f"Datetime column '{column_name}' does not have a specified 'datetime_format'. "
            'Please add a the required datetime_format to the metadata or convert this column '
            "to 'pd.datetime' to bypass this requirement."
        )

    # Convert exactly once and return that result; any failure is reported as a
    # ValueError with the original exception chained as its cause.
    try:
        return pd.to_datetime(column_data, format=datetime_format)
    except Exception as e:
        raise ValueError(f"Error converting column '{column_name}' to timestamp: {e}") from e


@handle_single_and_multi_table
def _convert_datetime_columns(data, metadata):
"""Convert datetime columns to datetime type."""
columns_missing_datetime_format = []
for column in metadata['columns']:
if metadata['columns'][column]['sdtype'] == 'datetime':
is_datetime = pd.api.types.is_datetime64_any_dtype(data[column])
if not is_datetime:
datetime_format = metadata['columns'][column].get('format')
try:
if datetime_format:
data[column] = pd.to_datetime(data[column], format=datetime_format)
else:
columns_missing_datetime_format.append(column)
data[column] = pd.to_datetime(data[column])
except Exception as e:
raise ValueError(
f"Failed to convert column '{column}' to datetime with the error: {str(e)}"
) from e

if columns_missing_datetime_format:
columns_to_print = "', '".join(columns_missing_datetime_format)
warnings.warn(
f'No `datetime_format` provided in the metadata when trying to convert the columns'
f" '{columns_to_print}' to datetime. The format will be inferred, but it may not"
' be accurate.',
UserWarning,
)
data[column] = _convert_datetime_column(
column, data[column], metadata['columns'][column]
)

return data

Expand All @@ -111,7 +113,7 @@ def _remove_missing_columns_metadata(data, metadata):
if columns_to_remove:
columns_to_print = "', '".join(sorted(columns_to_remove))
warnings.warn(
f"The columns ('{columns_to_print}') are not present in the metadata."
f"The columns ('{columns_to_print}') are not present in the metadata. "
'They will not be included for further evaluation.',
UserWarning,
)
Expand Down
7 changes: 3 additions & 4 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
import pandas as pd
import tqdm

from sdmetrics._utils_metadata import _validate_metadata
from sdmetrics.reports.utils import convert_datetime_columns
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata
from sdmetrics.visualization import set_plotly_config


Expand Down Expand Up @@ -101,8 +100,8 @@ def convert_datetimes(real_data, synthetic_data, metadata):
real_col = real_data[column]
synth_col = synthetic_data[column]
try:
converted_cols = convert_datetime_columns(real_col, synth_col, col_meta)
real_data[column], synthetic_data[column] = converted_cols
real_data[column] = _convert_datetime_column(column, real_col, col_meta)
synthetic_data[column] = _convert_datetime_column(column, synth_col, col_meta)
except Exception:
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from plotly import graph_objects as go
from plotly.subplots import make_subplots

from sdmetrics._utils_metadata import _convert_datetime_column
from sdmetrics.column_pairs.statistical import ContingencySimilarity, CorrelationSimilarity
from sdmetrics.reports.single_table._properties import BaseSingleTableProperty
from sdmetrics.reports.utils import PlotConfig
from sdmetrics.utils import is_datetime

DEFAULT_NUM_ROWS_SUBSAMPLE = 50000

Expand Down Expand Up @@ -51,13 +51,9 @@ def _convert_datetime_columns_to_numeric(self, data, metadata):
col_sdtype = column_meta['sdtype']
try:
if col_sdtype == 'datetime':
if not is_datetime(data[column_name]):
datetime_format = column_meta.get(
'datetime_format', column_meta.get('format')
)
data[column_name] = pd.to_datetime(
data[column_name], format=datetime_format
)
data[column_name] = _convert_datetime_column(
column_name, data[column_name], column_meta
)
nan_mask = pd.isna(data[column_name])
data[column_name] = pd.to_numeric(data[column_name])
if nan_mask.any():
Expand Down
54 changes: 3 additions & 51 deletions sdmetrics/reports/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import numpy as np
import pandas as pd
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdmetrics._utils_metadata import _convert_datetime_column
from sdmetrics.utils import (
discretize_column,
get_alternate_keys,
get_columns_from_metadata,
get_type_from_column_meta,
is_datetime,
)

CONTINUOUS_SDTYPES = ['numerical', 'datetime']
Expand All @@ -35,51 +34,6 @@ class PlotConfig:
FONT_SIZE = 18


def convert_to_datetime(column_data, datetime_format=None):
    """Convert a column data to pandas datetime.

    Args:
        column_data (pandas.Series):
            The column data
        datetime_format (str):
            Optional string format of datetime. If ``None``, will attempt to infer the datetime
            format from the column data. Defaults to ``None``.

    Returns:
        pandas.Series:
            The converted column data. The input is returned unchanged when
            ``is_datetime`` already reports it as datetime.
    """
    # Nothing to convert when the data is already recognised as datetime.
    if is_datetime(column_data):
        return column_data

    # No explicit format given: guess one from the string form of the values.
    if datetime_format is None:
        datetime_format = _guess_datetime_format_for_array(column_data.astype(str).to_numpy())

    return pd.to_datetime(column_data, format=datetime_format)


def convert_datetime_columns(real_column, synthetic_column, col_metadata):
    """Convert a real and a synthetic column to pandas datetime.

    Both columns are converted with the same format, read from the column metadata.

    Args:
        real_column (pandas.Series):
            The real column data
        synthetic_column (pandas.Series):
            The synthetic column data
        col_metadata:
            The metadata associated with the column

    Returns:
        (pandas.Series, pandas.Series):
            The converted real and synthetic column data.
    """
    # Prefer the 'format' key; fall back to 'datetime_format' when it is missing or falsy.
    datetime_format = col_metadata.get('format') or col_metadata.get('datetime_format')
    return (
        convert_to_datetime(real_column, datetime_format),
        convert_to_datetime(synthetic_column, datetime_format),
    )


def discretize_table_data(real_data, synthetic_data, metadata):
"""Create a copy of the real and synthetic data with discretized data.

Expand Down Expand Up @@ -109,10 +63,8 @@ def discretize_table_data(real_data, synthetic_data, metadata):
real_col = real_data[column_name]
synthetic_col = synthetic_data[column_name]
if sdtype == 'datetime':
datetime_format = column_meta.get('format') or column_meta.get('datetime_format')
if real_col.dtype == 'O' and datetime_format:
real_col = pd.to_datetime(real_col, format=datetime_format)
synthetic_col = pd.to_datetime(synthetic_col, format=datetime_format)
real_col = _convert_datetime_column(column_name, real_col, column_meta)
synthetic_col = _convert_datetime_column(column_name, synthetic_col, column_meta)

real_col = pd.to_numeric(real_col)
synthetic_col = pd.to_numeric(synthetic_col)
Expand Down
18 changes: 7 additions & 11 deletions sdmetrics/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import copy
from operator import attrgetter

import pandas as pd

from sdmetrics._utils_metadata import _validate_single_table_metadata
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_single_table_metadata
from sdmetrics.base import BaseMetric
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta
Expand Down Expand Up @@ -119,14 +117,12 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None):
field_type = get_type_from_column_meta(field_meta)
if field not in real_data.columns:
raise ValueError(f'Field {field} not found in data')
if (
field_type == 'datetime'
and 'datetime_format' in field_meta
and real_data[field].dtype == 'O'
):
dt_format = field_meta['datetime_format']
real_data[field] = pd.to_datetime(real_data[field], format=dt_format)
synthetic_data[field] = pd.to_datetime(synthetic_data[field], format=dt_format)

if field_type == 'datetime':
real_data[field] = _convert_datetime_column(field, real_data[field], field_meta)
synthetic_data[field] = _convert_datetime_column(
field, synthetic_data[field], field_meta
)

return real_data, synthetic_data, metadata

Expand Down
8 changes: 5 additions & 3 deletions sdmetrics/single_table/new_row_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pandas as pd

from sdmetrics._utils_metadata import _convert_datetime_column
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
Expand Down Expand Up @@ -83,9 +84,10 @@ def compute_breakdown(

for field, field_meta in get_columns_from_metadata(metadata).items():
if get_type_from_column_meta(field_meta) == 'datetime':
if len(real_data[field]) > 0 and isinstance(real_data[field][0], str):
real_data[field] = pd.to_datetime(real_data[field])
synthetic_data[field] = pd.to_datetime(synthetic_data[field])
real_data[field] = _convert_datetime_column(field, real_data[field], field_meta)
synthetic_data[field] = _convert_datetime_column(
field, synthetic_data[field], field_meta
)

real_data[field] = pd.to_numeric(real_data[field])
synthetic_data[field] = pd.to_numeric(synthetic_data[field])
Expand Down
24 changes: 9 additions & 15 deletions sdmetrics/timeseries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from operator import attrgetter

import pandas as pd

from sdmetrics._utils_metadata import _validate_metadata_dict
from sdmetrics._utils_metadata import _convert_datetime_column, _validate_metadata_dict
from sdmetrics.base import BaseMetric
from sdmetrics.utils import get_columns_from_metadata

Expand Down Expand Up @@ -62,18 +60,14 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None, sequence_key
if field not in real_data.columns:
raise ValueError(f'Field {field} not found in data')

for column, kwargs in metadata['columns'].items():
if kwargs['sdtype'] == 'datetime':
datetime_format = kwargs.get('datetime_format')
try:
real_data[column] = pd.to_datetime(
real_data[column], format=datetime_format
)
synthetic_data[column] = pd.to_datetime(
synthetic_data[column], format=datetime_format
)
except ValueError:
raise ValueError(f"Column '{column}' is not a valid datetime")
for column, col_metadata in metadata['columns'].items():
if col_metadata['sdtype'] == 'datetime':
real_data[column] = _convert_datetime_column(
column, real_data[column], col_metadata
)
synthetic_data[column] = _convert_datetime_column(
column, synthetic_data[column], col_metadata
)

else:
dtype_kinds = real_data.dtypes.apply(attrgetter('kind'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ def test_speical_data_metadata_config(self):
mask_validation = np.random.rand(len(real_data)) < 0.8
real_training = real_data[mask_validation]
real_validation = real_data[~mask_validation]
warning_datetime = re.escape(
'No `datetime_format` provided in the metadata when trying to convert the columns'
" 'start_date' to datetime. The format will be inferred, but it may not be accurate."
warning_extra_col = re.escape(
"The columns ('extra_column') are not present in the metadata. "
'They will not be included for further evaluation.'
)

# Run
with pytest.warns(UserWarning, match=warning_datetime):
with pytest.warns(UserWarning, match=warning_extra_col):
score = BinaryClassifierRecallEfficacy.compute(
real_training_data=real_training,
synthetic_data=synthetic_data,
Expand Down
8 changes: 5 additions & 3 deletions tests/integration/timeseries/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,22 @@ def test_compute_lstmdetection_multiple_categorical_columns():
def test_compute_lstmdetection_mismatching_datetime_columns():
"""Test LSTMDetection metric with mismatching datetime columns.

Test it when the real data has a date column and the synthetic data has a string column.
Test it when the real data has a datetime column and the synthetic data has a string column.
"""
# Setup
df1 = pd.DataFrame({
's_key': [1, 2, 3, 4, 5],
'visits': pd.to_datetime(['1/1/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019']),
})
df1['visits'] = df1['visits'].dt.date
df2 = pd.DataFrame({
's_key': [1, 2, 3, 4, 5],
'visits': ['1/2/2019', '1/2/2019', '1/3/2019', '1/4/2019', '1/5/2019'],
})
metadata = {
'columns': {'s_key': {'sdtype': 'numerical'}, 'visits': {'sdtype': 'datetime'}},
'columns': {
's_key': {'sdtype': 'numerical'},
'visits': {'sdtype': 'datetime', 'datetime_format': '%m/%d/%Y'},
},
'sequence_key': 's_key',
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ def test_convert_datetimes(self):
metadata = {
'tables': {
'table1': {
'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}},
'columns': {
'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
'col2': {'sdtype': 'datetime'},
},
},
},
}
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/reports/test_base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def test_convert_datetimes(self):
real_data = pd.DataFrame({'col1': ['2020-01-02', '2021-01-02'], 'col2': ['a', 'b']})
synthetic_data = pd.DataFrame({'col1': ['2022-01-03', '2023-04-05'], 'col2': ['b', 'a']})
metadata = {
'columns': {'col1': {'sdtype': 'datetime'}, 'col2': {'sdtype': 'datetime'}},
'columns': {
'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
'col2': {'sdtype': 'datetime'},
},
}

# Run
Expand Down
Loading
Loading