8 changes: 6 additions & 2 deletions Makefile
@@ -93,7 +93,11 @@ fix-lint:
# TEST TARGETS
.PHONY: test-unit
test-unit: ## run tests quickly with the default Python
python -m pytest --cov=sdgym
invoke unit

.PHONY: test-integration
test-integration: ## run integration tests with the default Python
invoke integration

.PHONY: test-readme
test-readme: ## run the readme snippets
@@ -102,7 +106,7 @@ test-readme: ## run the readme snippets
rm -rf tests/readme_test

.PHONY: test
test: test-unit test-readme ## test everything that needs test dependencies
test: test-unit test-integration ## test everything that needs test dependencies

.PHONY: test-devel
test-devel: lint ## test everything that needs development dependencies
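For context on the new targets, here is a minimal sketch of what the corresponding invoke tasks in tasks.py might look like; the task bodies and pytest options below are assumptions for illustration, not taken from this diff:

# tasks.py -- hypothetical sketch; the real task definitions live in the repository.
from invoke import task


@task
def unit(ctx):
    """Run the unit test suite (what `make test-unit` previously ran directly)."""
    ctx.run('python -m pytest ./tests/unit --cov=sdgym')


@task
def integration(ctx):
    """Run the integration test suite behind the new `make test-integration` target."""
    ctx.run('python -m pytest ./tests/integration')
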
11 changes: 8 additions & 3 deletions sdgym/__init__.py
@@ -12,7 +12,11 @@

import logging

from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
from sdgym.benchmark import (
benchmark_multi_table,
benchmark_single_table,
benchmark_single_table_aws,
)
from sdgym.cli.collect import collect_results
from sdgym.cli.summary import make_summary_spreadsheet
from sdgym.dataset_explorer import DatasetExplorer
@@ -31,12 +35,13 @@
__all__ = [
'DatasetExplorer',
'ResultsExplorer',
'benchmark_multi_table',
'benchmark_single_table',
'benchmark_single_table_aws',
'collect_results',
'create_synthesizer_variant',
'create_single_table_synthesizer',
'create_multi_table_synthesizer',
'create_single_table_synthesizer',
'create_synthesizer_variant',
'get_available_datasets',
'load_dataset',
'make_summary_spreadsheet',
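With this change the multi-table benchmark is exposed at the package level; a minimal import illustration (the function's signature is not shown in this diff, so only the import is given):

from sdgym import benchmark_multi_table, benchmark_single_table, benchmark_single_table_aws
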
114 changes: 86 additions & 28 deletions sdgym/_dataset_utils.py
@@ -7,9 +7,14 @@

import numpy as np
import pandas as pd
from sdv.metadata import Metadata
from sdv.utils import poc

LOGGER = logging.getLogger(__name__)

MAX_NUM_COLUMNS = 10
MAX_NUM_ROWS = 1000


def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
"""Generic parser for numeric values with logging and NaN fallback."""
@@ -23,6 +28,74 @@ def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
return np.nan


def _filter_columns(columns, mandatory_columns):
"""Given a dictionary of columns and a list of mandatory ones, return a filtered subset."""
all_columns = list(columns)
mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
optional_columns = [opt_col for opt_col in all_columns if opt_col not in mandatory_columns]

if len(mandatory_columns) >= MAX_NUM_COLUMNS:
keep_columns = mandatory_columns
elif len(all_columns) > MAX_NUM_COLUMNS:
keep_count = MAX_NUM_COLUMNS - len(mandatory_columns)
keep_columns = mandatory_columns + optional_columns[:keep_count]
else:
keep_columns = mandatory_columns + optional_columns

return {col: columns[col] for col in keep_columns if col in columns}


def _get_multi_table_dataset_subset(data, metadata_dict):
"""Create a smaller, referentially consistent subset of multi-table data.

This function limits each table to at most 10 columns by keeping all
mandatory columns and, if needed, a subset of the remaining columns, then
trims the underlying DataFrames to match the updated metadata. Finally, it
uses SDV's multi-table utility to sample up to 1,000 rows from
the main table and a consistent subset of rows from all related tables
while preserving referential integrity.

Args:
data (dict):
A dictionary where keys are table names and values are DataFrames
representing tables.
metadata_dict (dict):
Metadata dictionary containing schema information for each table.

Returns:
tuple:
A tuple containing:
- dict: The subset of the input data with reduced columns and rows.
- dict: The updated metadata dictionary reflecting the reduced column sets.
"""
metadata = Metadata.load_from_dict(metadata_dict)
for table_name, table in metadata.tables.items():
table_columns = table.columns
mandatory_columns = list(metadata._get_all_keys(table_name))
subset_column_schema = _filter_columns(
columns=table_columns, mandatory_columns=mandatory_columns
)
metadata_dict['tables'][table_name]['columns'] = subset_column_schema

# Re-load the metadata object that will be used with the `SDV` utility function
metadata = Metadata.load_from_dict(metadata_dict)
largest_table_name = max(data, key=lambda table_name: len(data[table_name]))

# Trim the data to contain only the subset of columns
for table_name, table in metadata.tables.items():
data[table_name] = data[table_name][list(table.columns)]

# Subsample the data, maintaining referential integrity
data = poc.get_random_subset(
data=data,
metadata=metadata,
main_table_name=largest_table_name,
num_rows=MAX_NUM_ROWS,
verbose=False,
)
return data, metadata_dict


def _get_dataset_subset(data, metadata_dict, modality):
"""Limit the size of a dataset for faster evaluation or testing.

@@ -31,52 +104,37 @@ def _get_dataset_subset(data, metadata_dict, modality):
columns—such as sequence indices and keys in sequential datasets—are always retained.

Args:
data (pd.DataFrame):
data (pd.DataFrame or dict):
The dataset to be reduced.
metadata_dict (dict):
A dictionary containing the dataset's metadata.
A dictionary representing the dataset's metadata.
modality (str):
The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``.
The dataset modality: ``'single_table'``, ``'sequential'`` or ``'multi_table'``.

Returns:
tuple[pd.DataFrame, dict]:
A tuple containing:
- The reduced dataset as a DataFrame.
- The reduced dataset as a DataFrame or dictionary.
- The updated metadata dictionary reflecting any removed columns.

Raises:
ValueError:
If the provided modality is ``'multi_table'``.
"""
if modality == 'multi_table':
raise ValueError('limit_dataset_size is not supported for multi-table datasets.')
return _get_multi_table_dataset_subset(data, metadata_dict)

max_rows, max_columns = (1000, 10)
tables = metadata_dict.get('tables', {})
mandatory_columns = []
table_name, table_info = next(iter(tables.items()))

columns = table_info.get('columns', {})
keep_columns = list(columns)
if modality == 'sequential':
seq_index = table_info.get('sequence_index')
seq_key = table_info.get('sequence_key')
mandatory_columns = [col for col in (seq_index, seq_key) if col]

optional_columns = [col for col in columns if col not in mandatory_columns]
seq_index = table_info.get('sequence_index')
seq_key = table_info.get('sequence_key')
mandatory_columns = [column for column in (seq_index, seq_key) if column]
filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns)

# If we have too many columns, drop extras but never mandatory ones
if len(columns) > max_columns:
keep_count = max_columns - len(mandatory_columns)
keep_columns = mandatory_columns + optional_columns[:keep_count]
table_info['columns'] = {
column_name: column_definition
for column_name, column_definition in columns.items()
if column_name in keep_columns
}

data = data[list(keep_columns)]
table_info['columns'] = filtered
data = data[list(filtered)]
max_rows = min(MAX_NUM_ROWS, len(data))
data = data.sample(max_rows)

return data, metadata_dict
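
For illustration, a small usage sketch of the new _filter_columns helper defined above; the column schemas here are made up, but the behavior follows directly from the code in this diff:

from sdgym._dataset_utils import MAX_NUM_COLUMNS, _filter_columns

# Twelve optional columns plus one mandatory key, mimicking an SDV column schema.
columns = {f'col_{i}': {'sdtype': 'numerical'} for i in range(12)}
columns['user_id'] = {'sdtype': 'id'}

subset = _filter_columns(columns=columns, mandatory_columns=['user_id'])
assert 'user_id' in subset             # mandatory columns are always kept
assert len(subset) == MAX_NUM_COLUMNS  # total is capped at 10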


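A hedged sketch of the multi-table path follows; the tables, metadata dictionary, and column names are invented for illustration, and the SDV metadata format shown is an assumption rather than something taken from this diff:

import pandas as pd

from sdgym._dataset_utils import _get_dataset_subset

# Hypothetical parent/child tables; real datasets come from the SDGym dataset store.
users = pd.DataFrame({'user_id': range(5000), 'age': [30] * 5000})
sessions = pd.DataFrame({
    'session_id': range(20000),
    'user_id': [i % 5000 for i in range(20000)],
})
data = {'users': users, 'sessions': sessions}

metadata_dict = {
    'tables': {
        'users': {
            'primary_key': 'user_id',
            'columns': {'user_id': {'sdtype': 'id'}, 'age': {'sdtype': 'numerical'}},
        },
        'sessions': {
            'primary_key': 'session_id',
            'columns': {'session_id': {'sdtype': 'id'}, 'user_id': {'sdtype': 'id'}},
        },
    },
    'relationships': [{
        'parent_table_name': 'users',
        'parent_primary_key': 'user_id',
        'child_table_name': 'sessions',
        'child_foreign_key': 'user_id',
    }],
}

subset_data, subset_metadata = _get_dataset_subset(data, metadata_dict, modality='multi_table')
# Each table keeps at most MAX_NUM_COLUMNS columns; the largest table is sampled down to
# MAX_NUM_ROWS rows and related tables keep only rows that preserve referential integrity.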
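
And a quick sketch of the refactored sequential path, showing that the sequence key and index always survive the column cap (the table and column names are made up):

import pandas as pd

from sdgym._dataset_utils import _get_dataset_subset

data = pd.DataFrame({
    'entity': ['a'] * 50,
    'timestamp': pd.date_range('2024-01-01', periods=50, freq='D'),
    **{f'value_{i}': range(50) for i in range(15)},
})
metadata_dict = {
    'tables': {
        'transactions': {
            'sequence_key': 'entity',
            'sequence_index': 'timestamp',
            'columns': {column: {'sdtype': 'numerical'} for column in data.columns},
        },
    },
}

subset, updated_metadata = _get_dataset_subset(data, metadata_dict, modality='sequential')
# 'entity' and 'timestamp' are retained as mandatory columns; at most 10 columns
# and min(MAX_NUM_ROWS, len(data)) rows remain.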