Skip to content
155 changes: 155 additions & 0 deletions sdmetrics/_utils_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import functools
import warnings

import pandas as pd

MODELABLE_SDTYPES = ('numerical', 'datetime', 'categorical', 'boolean')


def _validate_metadata_dict(metadata):
"""Validate the metadata type."""
if not isinstance(metadata, dict):
raise TypeError(
f"Expected a dictionary but received a '{type(metadata).__name__}' instead."
" For SDV metadata objects, please use the 'to_dict' function to convert it"
' to a dictionary.'
)


def _validate_single_table_metadata(metadata):
    """Validate the metadata for a single table.

    Args:
        metadata (dict):
            The single-table metadata dictionary.

    Raises:
        TypeError:
            If ``metadata`` is not a dictionary.
        ValueError:
            If the dictionary has no ``'columns'`` entry.
    """
    _validate_metadata_dict(metadata)
    if 'columns' in metadata:
        return

    raise ValueError(
        "Single-table metadata must include a 'columns' key that maps column names"
        ' to their corresponding information.'
    )


def _validate_multi_table_metadata(metadata):
    """Validate the metadata for multiple tables.

    Args:
        metadata (dict):
            The multi-table metadata dictionary.

    Raises:
        TypeError:
            If ``metadata`` is not a dictionary.
        ValueError:
            If the dictionary has no ``'tables'`` entry, or any per-table
            metadata fails single-table validation.
    """
    _validate_metadata_dict(metadata)
    if 'tables' not in metadata:
        raise ValueError(
            "Multi-table metadata must include a 'tables' key that maps table names"
            ' to their respective metadata.'
        )
    for table_name, table_metadata in metadata['tables'].items():
        try:
            _validate_single_table_metadata(table_metadata)
        except ValueError as e:
            # Chain the original error ('from e') so the traceback keeps the
            # root cause instead of reporting "During handling of the above
            # exception, another exception occurred".
            raise ValueError(f"Error in table '{table_name}': {str(e)}") from e


def _validate_metadata(metadata):
    """Validate metadata that may describe either a single table or many.

    Args:
        metadata (dict):
            The metadata dictionary to validate.

    Raises:
        TypeError:
            If ``metadata`` is not a dictionary.
        ValueError:
            If neither a ``'columns'`` nor a ``'tables'`` key is present, or
            multi-table validation fails.
    """
    _validate_metadata_dict(metadata)
    has_columns = 'columns' in metadata
    has_tables = 'tables' in metadata
    if not (has_columns or has_tables):
        raise ValueError(
            "Metadata must include either a 'columns' key for single-table metadata"
            " or a 'tables' key for multi-table metadata."
        )

    if has_tables:
        _validate_multi_table_metadata(metadata)


def handle_single_and_multi_table(single_table_func):
    """Decorator to handle both single and multi table functions.

    If ``data`` is a ``pandas.DataFrame``, the wrapped function is applied
    directly with the given metadata. Otherwise ``data`` is assumed to be a
    mapping of table names to tables, and the function is applied per table
    using the matching entry in ``metadata['tables']``.

    Args:
        single_table_func (callable):
            A function taking ``(data, metadata)`` for a single table.

    Returns:
        callable:
            A wrapper that dispatches on the type of ``data``.
    """

    # functools.wraps preserves the wrapped function's __name__/__doc__ so
    # introspection and error messages stay meaningful.
    @functools.wraps(single_table_func)
    def wrapper(data, metadata):
        if isinstance(data, pd.DataFrame):
            return single_table_func(data, metadata)

        return {
            table_name: single_table_func(data[table_name], metadata['tables'][table_name])
            for table_name in data
        }

    return wrapper


@handle_single_and_multi_table
def _convert_datetime_columns(data, metadata):
    """Convert datetime columns to datetime type.

    Columns whose metadata sdtype is ``'datetime'`` but whose dtype is not
    already datetime are cast in place with ``pd.to_datetime``, using the
    metadata ``'format'`` when one is provided. Columns converted without a
    format are collected and reported in a single ``UserWarning``.
    """
    inferred_format_columns = []
    for column_name, column_meta in metadata['columns'].items():
        if column_meta['sdtype'] != 'datetime':
            continue

        if pd.api.types.is_datetime64_any_dtype(data[column_name]):
            continue

        fmt = column_meta.get('format')
        try:
            # Truthiness check: an empty-string format also falls back to
            # inference, matching pd.to_datetime's own handling.
            if fmt:
                data[column_name] = pd.to_datetime(data[column_name], format=fmt)
            else:
                inferred_format_columns.append(column_name)
                data[column_name] = pd.to_datetime(data[column_name])
        except Exception as error:
            raise ValueError(
                f"Failed to convert column '{column_name}' to datetime with the error: "
                f'{str(error)}'
            ) from error

    if inferred_format_columns:
        columns_to_print = "', '".join(inferred_format_columns)
        warnings.warn(
            f'No `datetime_format` provided in the metadata when trying to convert the columns'
            f" '{columns_to_print}' to datetime. The format will be inferred, but it may not"
            ' be accurate.',
            UserWarning,
        )

    return data


@handle_single_and_multi_table
def _remove_missing_columns_metadata(data, metadata):
    """Remove columns that are not present in the metadata.

    Warns when the data has columns missing from the metadata (those columns
    are dropped) or, otherwise, when the metadata lists columns missing from
    the data. Returns the data restricted to the metadata columns, preserving
    the data's original column order.
    """
    columns_in_metadata = set(metadata['columns'].keys())
    columns_in_data = set(data.columns)
    columns_to_remove = columns_in_data - columns_in_metadata
    extra_metadata_columns = columns_in_metadata - columns_in_data
    if columns_to_remove:
        columns_to_print = "', '".join(sorted(columns_to_remove))
        warnings.warn(
            # Fixed message: the original concatenation was missing a space,
            # producing '...in the metadata.They will not...'.
            f"The columns ('{columns_to_print}') are not present in the metadata. "
            'They will not be included for further evaluation.',
            UserWarning,
        )
    elif extra_metadata_columns:
        columns_to_print = "', '".join(sorted(extra_metadata_columns))
        warnings.warn(
            f"The columns ('{columns_to_print}') are in the metadata but they are not "
            'present in the data.',
            UserWarning,
        )

    data = data.drop(columns=columns_to_remove)
    # Re-filter against the metadata to keep the data's column order.
    column_intersection = [column for column in data.columns if column in metadata['columns']]

    return data[column_intersection]


@handle_single_and_multi_table
def _remove_non_modelable_columns(data, metadata):
    """Remove columns that are not modelable.

    All modelable columns are numerical, datetime, categorical, or boolean sdtypes.
    """
    modelable_columns = [
        column_name
        for column_name, column_meta in metadata['columns'].items()
        if column_meta['sdtype'] in MODELABLE_SDTYPES and column_name in data.columns
    ]

    return data[modelable_columns]


def _process_data_with_metadata(data, metadata, keep_modelable_columns_only=False):
    """Process the data according to the metadata.

    Converts datetime columns, drops columns absent from the metadata and,
    optionally, drops columns whose sdtype is not modelable.

    Args:
        data (pandas.DataFrame or dict):
            A single table or a mapping of table names to tables.
        metadata (dict):
            The metadata describing ``data``.
        keep_modelable_columns_only (bool):
            Whether to also drop non-modelable columns. Defaults to ``False``.

    Returns:
        The processed data, in the same single/multi-table shape as the input.
    """
    _validate_metadata_dict(metadata)
    processed = _convert_datetime_columns(data, metadata)
    processed = _remove_missing_columns_metadata(processed, metadata)
    if not keep_modelable_columns_only:
        return processed

    return _remove_non_modelable_columns(processed, metadata)
5 changes: 2 additions & 3 deletions sdmetrics/multi_table/detection/parent_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from sdmetrics._utils_metadata import _validate_multi_table_metadata
from sdmetrics.multi_table.detection.base import DetectionMetric
from sdmetrics.single_table.detection import LogisticDetection, SVCDetection
from sdmetrics.utils import get_columns_from_metadata, nested_attrs_meta
Expand Down Expand Up @@ -37,9 +38,7 @@ class ParentChildDetectionMetric(

@staticmethod
def _extract_foreign_keys(metadata):
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

_validate_multi_table_metadata(metadata)
foreign_keys = []
for child_table, child_meta in metadata['tables'].items():
for child_key, field_meta in get_columns_from_metadata(child_meta).items():
Expand Down
4 changes: 2 additions & 2 deletions sdmetrics/multi_table/multi_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from sdmetrics import single_table
from sdmetrics._utils_metadata import _validate_multi_table_metadata
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.multi_table.base import MultiTableMetric
from sdmetrics.utils import nested_attrs_meta
Expand Down Expand Up @@ -77,9 +78,8 @@ def _compute(self, real_data, synthetic_data, metadata=None, **kwargs):

if metadata is None:
metadata = {'tables': defaultdict(type(None))}
elif not isinstance(metadata, dict):
metadata = metadata.to_dict()

_validate_multi_table_metadata(metadata)
scores = {}
errors = {}
for table_name, real_table in real_data.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numpy as np
from scipy.stats import ks_2samp

from sdmetrics._utils_metadata import (
_validate_metadata_dict,
)
from sdmetrics.goal import Goal
from sdmetrics.multi_table.base import MultiTableMetric
from sdmetrics.utils import get_cardinality_distribution
Expand Down Expand Up @@ -53,9 +56,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata):
"""
if set(real_data.keys()) != set(synthetic_data.keys()):
raise ValueError('`real_data` and `synthetic_data` must have the same tables.')
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

_validate_metadata_dict(metadata)
score_breakdowns = {}
for rel in metadata.get('relationships', []):
cardinality_real = get_cardinality_distribution(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from sdmetrics._utils_metadata import _validate_metadata_dict
from sdmetrics.goal import Goal
from sdmetrics.multi_table.base import MultiTableMetric
from sdmetrics.utils import get_cardinality_distribution
Expand Down Expand Up @@ -107,9 +108,8 @@ def compute_breakdown(cls, real_data, synthetic_data, metadata=None, statistic='
raise ValueError('`real_data` and `synthetic_data` must have the same tables.')
if metadata is None:
raise ValueError('`metadata` cannot be ``None``.')
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

_validate_metadata_dict(metadata)
score_breakdowns = {}
for rel in metadata.get('relationships', []):
cardinality_real = get_cardinality_distribution(
Expand Down
21 changes: 2 additions & 19 deletions sdmetrics/reports/base_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pandas as pd
import tqdm

from sdmetrics._utils_metadata import _validate_metadata
from sdmetrics.reports.utils import convert_datetime_columns
from sdmetrics.visualization import set_plotly_config

Expand Down Expand Up @@ -68,17 +69,6 @@ def _validate_data_format(self, real_data, synthetic_data):
)
raise ValueError(error_message)

def _validate_metadata_format(self, metadata):
"""Validate the metadata."""
if not isinstance(metadata, dict):
raise TypeError('The provided metadata is not a dictionary.')

if 'columns' not in metadata:
raise ValueError(
'Single table reports expect metadata to contain a "columns" key with a mapping'
' from column names to column informations.'
)

def _validate(self, real_data, synthetic_data, metadata):
"""Validate the inputs.

Expand All @@ -91,7 +81,7 @@ def _validate(self, real_data, synthetic_data, metadata):
The metadata of the table.
"""
self._validate_data_format(real_data, synthetic_data)
self._validate_metadata_format(metadata)
_validate_metadata(metadata)
self._validate_metadata_matches_data(real_data, synthetic_data, metadata)

@staticmethod
Expand Down Expand Up @@ -142,13 +132,6 @@ def generate(self, real_data, synthetic_data, metadata, verbose=True):
verbose (bool):
Whether or not to print report summary and progress.
"""
if not isinstance(metadata, dict):
raise TypeError(
f"Expected a dictionary but received a '{type(metadata).__name__}' instead."
" For SDV metadata objects, please use the 'to_dict' function to convert it"
' to a dictionary.'
)

self._validate(real_data, synthetic_data, metadata)
self.convert_datetimes(real_data, synthetic_data, metadata)

Expand Down
16 changes: 0 additions & 16 deletions sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,6 @@ def _validate_data_format(self, real_data, synthetic_data):

raise ValueError(error_message)

def _validate_metadata_format(self, metadata):
"""Validate the metadata."""
if not isinstance(metadata, dict):
raise TypeError('The provided metadata is not a dictionary.')

if 'tables' not in metadata:
raise ValueError(
'Multi table reports expect metadata to contain a "tables" key with a mapping'
' from table names to metadata for each table.'
)
for table_name, table_metadata in metadata['tables'].items():
if 'columns' not in table_metadata:
raise ValueError(
f'The metadata for table "{table_name}" is missing a "columns" key.'
)

def _validate_relationships(self, real_data, synthetic_data, metadata):
"""Validate that the relationships are valid."""
for rel in metadata.get('relationships', []):
Expand Down
7 changes: 3 additions & 4 deletions sdmetrics/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd

from sdmetrics._utils_metadata import _validate_single_table_metadata
from sdmetrics.base import BaseMetric
from sdmetrics.errors import IncomputableMetricError
from sdmetrics.utils import get_alternate_keys, get_columns_from_metadata, get_type_from_column_meta
Expand Down Expand Up @@ -92,7 +93,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None):
The real data.
synthetic_data(pandas.DataFrame):
The synthetic data.
metadata (dict or Metadata or None):
metadata (dict):
The metadata, if any.

Returns:
Expand All @@ -108,9 +109,7 @@ def _validate_inputs(cls, real_data, synthetic_data, metadata=None):
raise ValueError('`real_data` and `synthetic_data` must have the same columns')

if metadata is not None:
if not isinstance(metadata, dict):
metadata = metadata.to_dict()

_validate_single_table_metadata(metadata)
fields = get_columns_from_metadata(metadata)
for column in real_data.columns:
if column not in fields:
Expand Down
16 changes: 14 additions & 2 deletions sdmetrics/single_table/data_augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.single_table.data_augmentation.utils import _validate_inputs
from sdmetrics.single_table.data_augmentation.utils import (
_process_data_with_metadata_ml_efficacy_metrics,
_validate_inputs,
)

METRIC_NAME_TO_METHOD = {'recall': recall_score, 'precision': precision_score}

Expand Down Expand Up @@ -104,7 +107,11 @@ def _fit(cls, data, metadata, prediction_column_name):
"""Fit preprocessing parameters."""
discrete_columns = []
datetime_columns = []
for column, column_meta in metadata['columns'].items():
data_columns = data.columns
metadata_columns = metadata['columns'].keys()
common_columns = set(data_columns).intersection(metadata_columns)
for column in sorted(common_columns):
column_meta = metadata['columns'][column]
if (column_meta['sdtype'] in ['categorical', 'boolean']) and (
column != prediction_column_name
):
Expand Down Expand Up @@ -192,6 +199,11 @@ def compute_breakdown(
classifier,
fixed_value,
)
(real_training_data, synthetic_data, real_validation_data) = (
_process_data_with_metadata_ml_efficacy_metrics(
real_training_data, synthetic_data, real_validation_data, metadata
)
)
preprocessed_tables = cls._fit_transform(
real_training_data,
synthetic_data,
Expand Down
Loading
Loading