diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml
index 1171d63b..b7f6d6b4 100644
--- a/.github/workflows/minimum.yml
+++ b/.github/workflows/minimum.yml
@@ -12,7 +12,7 @@ concurrency:
 jobs:
   minimum:
     runs-on: ${{ matrix.os }}
-    timeout-minutes: 30
+    timeout-minutes: 45
     strategy:
       matrix:
         python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
diff --git a/Makefile b/Makefile
index 30a13ece..dce4a909 100644
--- a/Makefile
+++ b/Makefile
@@ -93,7 +93,11 @@ fix-lint:
 # TEST TARGETS
 .PHONY: test-unit
 test-unit: ## run tests quickly with the default Python
-	python -m pytest --cov=sdgym
+	invoke unit
+
+.PHONY: test-integration
+test-integration: ## run integration tests with the default Python
+	invoke integration
 
 .PHONY: test-readme
 test-readme: ## run the readme snippets
@@ -102,7 +106,7 @@ test-readme: ## run the readme snippets
 	rm -rf tests/readme_test
 
 .PHONY: test
-test: test-unit test-readme ## test everything that needs test dependencies
+test: test-unit test-integration ## test everything that needs test dependencies
 
 .PHONY: test-devel
 test-devel: lint ## test everything that needs development dependencies
diff --git a/sdgym/__init__.py b/sdgym/__init__.py
index dc2e0fc7..ad23a614 100644
--- a/sdgym/__init__.py
+++ b/sdgym/__init__.py
@@ -12,7 +12,11 @@
 
 import logging
 
-from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
+from sdgym.benchmark import (
+    benchmark_multi_table,
+    benchmark_single_table,
+    benchmark_single_table_aws,
+)
 from sdgym.cli.collect import collect_results
 from sdgym.cli.summary import make_summary_spreadsheet
 from sdgym.dataset_explorer import DatasetExplorer
@@ -31,12 +35,13 @@
 __all__ = [
     'DatasetExplorer',
     'ResultsExplorer',
+    'benchmark_multi_table',
     'benchmark_single_table',
     'benchmark_single_table_aws',
     'collect_results',
-    'create_synthesizer_variant',
-    'create_single_table_synthesizer',
     'create_multi_table_synthesizer',
+    'create_single_table_synthesizer',
+    'create_synthesizer_variant',
     'get_available_datasets',
     'load_dataset',
     'make_summary_spreadsheet',
diff --git a/sdgym/_dataset_utils.py b/sdgym/_dataset_utils.py
index 3e52998e..1bbc454c 100644
--- a/sdgym/_dataset_utils.py
+++ b/sdgym/_dataset_utils.py
@@ -7,9 +7,14 @@
 
 import numpy as np
 import pandas as pd
+from sdv.metadata import Metadata
+from sdv.utils import poc
 
 LOGGER = logging.getLogger(__name__)
 
+MAX_NUM_COLUMNS = 10
+MAX_NUM_ROWS = 1000
+
 
 def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     """Generic parser for numeric values with logging and NaN fallback."""
@@ -23,6 +28,74 @@ def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     return np.nan
 
 
+def _filter_columns(columns, mandatory_columns):
+    """Given a dictionary of columns and a list of mandatory ones, return a filtered subset."""
+    all_columns = list(columns)
+    mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
+    optional_columns = [opt_col for opt_col in all_columns if opt_col not in mandatory_columns]
+
+    if len(mandatory_columns) >= MAX_NUM_COLUMNS:
+        keep_columns = mandatory_columns
+    elif len(all_columns) > MAX_NUM_COLUMNS:
+        keep_count = MAX_NUM_COLUMNS - len(mandatory_columns)
+        keep_columns = mandatory_columns + optional_columns[:keep_count]
+    else:
+        keep_columns = mandatory_columns + optional_columns
+
+    return {col: columns[col] for col in keep_columns if col in columns}
+
+
+def _get_multi_table_dataset_subset(data, metadata_dict):
+    """Create a smaller, referentially consistent subset of multi-table data.
+
+    This function limits each table to at most 10 columns by keeping all
+    mandatory columns and, if needed, a subset of the remaining columns, then
+    trims the underlying DataFrames to match the updated metadata. Finally, it
+    uses SDV's multi-table utility to sample up to 1,000 rows from
+    the main table and a consistent subset of rows from all related tables
+    while preserving referential integrity.
+
+    Args:
+        data (dict):
+            A dictionary where keys are table names and values are DataFrames
+            representing tables.
+        metadata_dict (dict):
+            Metadata dictionary containing schema information for each table.
+
+    Returns:
+        tuple:
+            A tuple containing:
+            - dict: The subset of the input data with reduced columns and rows.
+            - dict: The updated metadata dictionary reflecting the reduced column sets.
+    """
+    metadata = Metadata.load_from_dict(metadata_dict)
+    for table_name, table in metadata.tables.items():
+        table_columns = table.columns
+        mandatory_columns = list(metadata._get_all_keys(table_name))
+        subset_column_schema = _filter_columns(
+            columns=table_columns, mandatory_columns=mandatory_columns
+        )
+        metadata_dict['tables'][table_name]['columns'] = subset_column_schema
+
+    # Re-load the metadata object that will be used with the `SDV` utility function
+    metadata = Metadata.load_from_dict(metadata_dict)
+    largest_table_name = max(data, key=lambda table_name: len(data[table_name]))
+
+    # Trim the data to contain only the subset of columns
+    for table_name, table in metadata.tables.items():
+        data[table_name] = data[table_name][list(table.columns)]
+
+    # Subsample the data maintaining the referential integrity
+    data = poc.get_random_subset(
+        data=data,
+        metadata=metadata,
+        main_table_name=largest_table_name,
+        num_rows=MAX_NUM_ROWS,
+        verbose=False,
+    )
+    return data, metadata_dict
+
+
 def _get_dataset_subset(data, metadata_dict, modality):
     """Limit the size of a dataset for faster evaluation or testing.
 
@@ -31,52 +104,37 @@ def _get_dataset_subset(data, metadata_dict, modality):
     columns—such as sequence indices and keys in sequential datasets—are
     always retained.
 
     Args:
-        data (pd.DataFrame):
+        data (pd.DataFrame or dict):
             The dataset to be reduced.
         metadata_dict (dict):
-            A dictionary containing the dataset's metadata.
+            A dictionary representing the dataset's metadata.
         modality (str):
-            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``.
+            The dataset modality.
 
     Returns:
         tuple[pd.DataFrame, dict]:
             A tuple containing:
-            - The reduced dataset as a DataFrame.
+            - The reduced dataset as a DataFrame or Dictionary.
             - The updated metadata dictionary reflecting any removed columns.
-
-    Raises:
-        ValueError:
-            If the provided modality is ``'multi_table'``.
     """
     if modality == 'multi_table':
-        raise ValueError('limit_dataset_size is not supported for multi-table datasets.')
+        return _get_multi_table_dataset_subset(data, metadata_dict)
 
-    max_rows, max_columns = (1000, 10)
     tables = metadata_dict.get('tables', {})
     mandatory_columns = []
     table_name, table_info = next(iter(tables.items()))
-
     columns = table_info.get('columns', {})
-    keep_columns = list(columns)
-    if modality == 'sequential':
-        seq_index = table_info.get('sequence_index')
-        seq_key = table_info.get('sequence_key')
-        mandatory_columns = [col for col in (seq_index, seq_key) if col]
-        optional_columns = [col for col in columns if col not in mandatory_columns]
+    seq_index = table_info.get('sequence_index')
+    seq_key = table_info.get('sequence_key')
+    mandatory_columns = [column for column in (seq_index, seq_key) if column]
+    filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns)
 
-        # If we have too many columns, drop extras but never mandatory ones
-        if len(columns) > max_columns:
-            keep_count = max_columns - len(mandatory_columns)
-            keep_columns = mandatory_columns + optional_columns[:keep_count]
-    table_info['columns'] = {
-        column_name: column_definition
-        for column_name, column_definition in columns.items()
-        if column_name in keep_columns
-    }
-
-    data = data[list(keep_columns)]
+    table_info['columns'] = filtered
+    data = data[list(filtered)]
+    max_rows = min(MAX_NUM_ROWS, len(data))
    data = data.sample(max_rows)
+
    return data, metadata_dict
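To make the truncation rule concrete, here is a small worked example of `_filter_columns` under the `MAX_NUM_COLUMNS = 10` cap (a hypothetical sketch; the column names are illustrative):

```python
from sdgym._dataset_utils import _filter_columns

# 15 candidate columns, two of which are mandatory (e.g. primary/foreign keys).
columns = {f'col_{i}': {'sdtype': 'numerical'} for i in range(15)}
mandatory = ['col_3', 'col_7']

# Mandatory columns are always kept; optional columns fill the
# remaining slots up to MAX_NUM_COLUMNS (10), in their original order.
subset = _filter_columns(columns=columns, mandatory_columns=mandatory)
assert len(subset) == 10
assert 'col_3' in subset and 'col_7' in subset
```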
""" if modality == 'multi_table': - raise ValueError('limit_dataset_size is not supported for multi-table datasets.') + return _get_multi_table_dataset_subset(data, metadata_dict) - max_rows, max_columns = (1000, 10) tables = metadata_dict.get('tables', {}) mandatory_columns = [] table_name, table_info = next(iter(tables.items())) - columns = table_info.get('columns', {}) - keep_columns = list(columns) - if modality == 'sequential': - seq_index = table_info.get('sequence_index') - seq_key = table_info.get('sequence_key') - mandatory_columns = [col for col in (seq_index, seq_key) if col] - optional_columns = [col for col in columns if col not in mandatory_columns] + seq_index = table_info.get('sequence_index') + seq_key = table_info.get('sequence_key') + mandatory_columns = [column for column in (seq_index, seq_key) if column] + filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns) - # If we have too many columns, drop extras but never mandatory ones - if len(columns) > max_columns: - keep_count = max_columns - len(mandatory_columns) - keep_columns = mandatory_columns + optional_columns[:keep_count] - table_info['columns'] = { - column_name: column_definition - for column_name, column_definition in columns.items() - if column_name in keep_columns - } - - data = data[list(keep_columns)] + table_info['columns'] = filtered + data = data[list(filtered)] + max_rows = min(MAX_NUM_ROWS, len(data)) data = data.sample(max_rows) + return data, metadata_dict diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 64bbbbe2..1824b0b9 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -5,7 +5,6 @@ import math import multiprocessing import os -import pickle import re import textwrap import threading @@ -52,7 +51,7 @@ write_csv, write_file, ) -from sdgym.synthesizers import UniformSynthesizer +from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer from sdgym.synthesizers.base import BaselineSynthesizer from sdgym.utils import ( calculate_score_time, @@ -67,8 +66,13 @@ ) LOGGER = logging.getLogger(__name__) -DEFAULT_SYNTHESIZERS = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'UniformSynthesizer'] -DEFAULT_DATASETS = [ +DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ + 'GaussianCopulaSynthesizer', + 'CTGANSynthesizer', + 'UniformSynthesizer', +] +DEFAULT_MULTI_TABLE_SYNTHESIZERS = ['MultiTableUniformSynthesizer', 'HMASynthesizer'] +DEFAULT_SINGLE_TABLE_DATASETS = [ 'adult', 'alarm', 'census', @@ -79,6 +83,16 @@ 'intrusion', 'news', ] +DEFAULT_MULTI_TABLE_DATASETS = [ + 'NBA', + 'financial', + 'Student_loan', + 'Biodegradability', + 'fake_hotels', + 'restbase', + 'airbnb-simplified', +] + N_BYTES_IN_MB = 1000 * 1000 EXTERNAL_SYNTHESIZER_TO_LIBRARY = { 'RealTabFormerSynthesizer': 'realtabformer', @@ -92,6 +106,9 @@ 'CopulaGANSynthesizer', 'TVAESynthesizer', ] +SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer'] + +SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers): @@ -106,7 +123,7 @@ def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, cus 'Please provide a folder that does not already exist.' 
) - duplicates = get_duplicates(synthesizers) if synthesizers else {} + duplicates = get_duplicates(synthesizers) if synthesizers else set() if custom_synthesizers: duplicates.update(get_duplicates(custom_synthesizers)) if len(duplicates) > 0: @@ -136,6 +153,7 @@ def _get_metainfo_increment(top_folder, s3_client=None): if match: # Extract numeric suffix (e.g. metainfo(3).yaml → 3) or 0 if plain metainfo.yaml increments.append(int(match.group(1)) if match.group(1) else 0) + except Exception: LOGGER.info(first_file_message) return 0 # start with (0) if error @@ -188,7 +206,13 @@ def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3 return paths -def _setup_output_destination(output_destination, synthesizers, datasets, s3_client=None): +def _setup_output_destination( + output_destination, + synthesizers, + datasets, + modality, + s3_client=None, +): """Set up the output destination for the benchmark results. Args: @@ -199,6 +223,10 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli The list of synthesizers to benchmark. datasets (list): The list of datasets to benchmark. + modality (str): + The dataset modality to load (e.g., 'single_table' or 'multi_table'). + s3_client (boto3.session.Session.client or None): + The s3 client that can be used to read / write to s3. Defaults to ``None``. """ if s3_client: return _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client) @@ -209,11 +237,12 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli output_path = Path(output_destination) output_path.mkdir(parents=True, exist_ok=True) today = datetime.today().strftime('%m_%d_%Y') - top_folder = output_path / f'SDGym_results_{today}' + top_folder = output_path / modality / f'SDGym_results_{today}' top_folder.mkdir(parents=True, exist_ok=True) increment = _get_metainfo_increment(top_folder) suffix = f'({increment})' if increment >= 1 else '' paths = defaultdict(dict) + synthetic_data_extension = 'zip' if modality == 'multi_table' else 'csv' for dataset in datasets: dataset_folder = top_folder / f'{dataset}_{today}' dataset_folder.mkdir(parents=True, exist_ok=True) @@ -224,7 +253,9 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli synth_folder.mkdir(parents=True, exist_ok=True) paths[dataset][final_synth_name] = { 'synthesizer': str(synth_folder / f'{final_synth_name}.pkl'), - 'synthetic_data': str(synth_folder / f'{final_synth_name}_synthetic_data.csv'), + 'synthetic_data': str( + synth_folder / f'{final_synth_name}_synthetic_data.{synthetic_data_extension}' + ), 'benchmark_result': str(synth_folder / f'{final_synth_name}_benchmark_result.csv'), 'metainfo': str(top_folder / f'metainfo{suffix}.yaml'), 'results': str(top_folder / f'results{suffix}.csv'), @@ -247,6 +278,7 @@ def _generate_job_args_list( synthesizers, custom_synthesizers, s3_client, + modality, ): # Get list of synthesizer objects synthesizers = [] if synthesizers is None else synthesizers @@ -260,7 +292,7 @@ def _generate_job_args_list( [] if sdv_datasets is None else get_dataset_paths( - modality='single_table', + modality=modality, datasets=sdv_datasets, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key_key, @@ -270,11 +302,11 @@ def _generate_job_args_list( [] if additional_datasets_folder is None else get_dataset_paths( - modality='single_table', + modality=modality, bucket=( additional_datasets_folder if is_s3_path(additional_datasets_folder) - else 
os.path.join(additional_datasets_folder, 'single_table') + else os.path.join(additional_datasets_folder, modality) ), aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key_key, @@ -284,7 +316,7 @@ def _generate_job_args_list( synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] paths = _setup_output_destination( - output_destination, synthesizer_names, dataset_names, s3_client=s3_client + output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) job_tuples = [] for dataset in datasets: @@ -302,9 +334,7 @@ def _generate_job_args_list( job_args_list = [] for synthesizer, dataset in job_tuples: - data, metadata_dict = load_dataset( - 'single_table', dataset, limit_dataset_size=limit_dataset_size - ) + data, metadata_dict = load_dataset(modality, dataset, limit_dataset_size=limit_dataset_size) path = paths.get(dataset.name, {}).get(synthesizer['name'], None) args = ( synthesizer, @@ -317,7 +347,7 @@ def _generate_job_args_list( compute_diagnostic_score, compute_privacy_score, dataset.name, - 'single_table', + modality, path, ) job_args_list.append(args) @@ -325,7 +355,14 @@ def _generate_job_args_list( return job_args_list -def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, result_writer=None): +def _synthesize( + synthesizer_dict, + real_data, + metadata, + modality, + synthesizer_path=None, + result_writer=None, +): synthesizer = synthesizer_dict['synthesizer'] if isinstance(synthesizer, type): assert issubclass(synthesizer, BaselineSynthesizer), ( @@ -340,22 +377,33 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re get_synthesizer = synthesizer.get_trained_synthesizer sample_from_synthesizer = synthesizer.sample_from_synthesizer data = real_data.copy() - num_samples = len(data) tracemalloc.start() now = get_utc_now() synthesizer_obj = get_synthesizer(data, metadata) - synthesizer_size = len(pickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB + synthesizer_size = len(cloudpickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB train_now = get_utc_now() - synthetic_data = sample_from_synthesizer(synthesizer_obj, num_samples) - sample_now = get_utc_now() + if modality == 'multi_table': + synthetic_data = sample_from_synthesizer(synthesizer_obj, 1.0) + else: + synthetic_data = sample_from_synthesizer(synthesizer_obj, n_samples=len(data)) + + sample_now = get_utc_now() peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB tracemalloc.stop() tracemalloc.clear_traces() + if synthesizer_path is not None and result_writer is not None: - result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data']) - result_writer.write_pickle(synthesizer_obj, synthesizer_path['synthesizer']) + internal_synthesizer = getattr(synthesizer_obj, '_internal_synthesizer', synthesizer_obj) + result_writer.write_pickle(internal_synthesizer, synthesizer_path['synthesizer']) + if modality == 'multi_table': + result_writer.write_zipped_dataframes( + synthetic_data, synthesizer_path['synthetic_data'] + ) + + else: + result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data']) return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory @@ -373,9 +421,13 @@ def _compute_scores( dataset_name, ): metrics = metrics or [] - sdmetrics_metadata = convert_metadata_to_sdmetrics(metadata) + if modality == 'single_table': + sdmetrics_metadata = 
convert_metadata_to_sdmetrics(metadata) + else: + sdmetrics_metadata = metadata + if len(metrics) > 0: - metrics, metric_kwargs = get_metrics(metrics, modality='single-table') + metrics, metric_kwargs = get_metrics(metrics, modality=modality.replace('_', '-')) scores = [] output['scores'] = scores for metric_name, metric in metrics.items(): @@ -484,9 +536,10 @@ def _score( # To be deleted if there is no error output['error'] = 'Synthesizer Timeout' synthetic_data, train_time, sample_time, synthesizer_size, peak_memory = _synthesize( - synthesizer, - data.copy(), - metadata, + synthesizer_dict=synthesizer, + real_data=data.copy(), + metadata=metadata, + modality=modality, synthesizer_path=synthesizer_path, result_writer=result_writer, ) @@ -799,6 +852,7 @@ def _run_jobs(multi_processing_config, job_args_list, show_progress, result_writ job_args_list = [job_args + (result_writer,) for job_args in job_args_list] if workers in (0, 1): scores = map(_run_job, job_args_list) + elif workers != 'dask': pool = concurrent.futures.ProcessPoolExecutor(workers) scores = pool.map(_run_job, job_args_list) @@ -1051,7 +1105,7 @@ def _validate_output_destination(output_destination, aws_keys=None): ) -def _write_metainfo_file(synthesizers, job_args_list, result_writer=None): +def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=None): jobs = [[job[-3], job[0]['name']] for job in job_args_list] if not job_args_list or not job_args_list[0][-1]: return @@ -1068,17 +1122,20 @@ def _write_metainfo_file(synthesizers, job_args_list, result_writer=None): date_str = date_match.group(1) metadata = { 'run_id': f'run_{date_str}_{increment}', + 'modality': modality, 'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'), 'completed_date': None, 'sdgym_version': version('sdgym'), 'jobs': jobs, } + for synthesizer in synthesizers: - if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS: + if synthesizer not in SDV_SYNTHESIZERS: ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer) if ext_lib: library_version = version(ext_lib) metadata[f'{ext_lib}_version'] = library_version + elif 'sdv' not in metadata.keys(): metadata['sdv_version'] = version('sdv') @@ -1099,6 +1156,16 @@ def _ensure_uniform_included(synthesizers): synthesizers.append('UniformSynthesizer') +def _ensure_multi_table_uniform_is_included(synthesizers): + uniform_not_included = bool( + MultiTableUniformSynthesizer not in synthesizers + and MultiTableUniformSynthesizer.__name__ not in synthesizers + ) + if uniform_not_included: + LOGGER.info('Adding MultiTableUniformSynthesizer to the list of synthesizers.') + synthesizers.append('MultiTableUniformSynthesizer') + + def _fill_adjusted_scores_with_none(scores): """Fill adjusted total time and quality score with NaN values.""" scores['Adjusted_Total_Time'] = None @@ -1157,9 +1224,9 @@ def _add_adjusted_scores(scores, timeout): def benchmark_single_table( - synthesizers=DEFAULT_SYNTHESIZERS, + synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS, custom_synthesizers=None, - sdv_datasets=DEFAULT_DATASETS, + sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS, additional_datasets_folder=None, limit_dataset_size=False, compute_quality_score=True, @@ -1273,6 +1340,7 @@ def benchmark_single_table( _create_instance_on_ec2(script_content) else: raise ValueError('In order to run on EC2, please provide an S3 folder output.') + return None _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers) @@ -1291,9 +1359,15 @@ def benchmark_single_table( synthesizers, 
custom_synthesizers, s3_client=None, + modality='single_table', ) - _write_metainfo_file(synthesizers, job_args_list, result_writer) + _write_metainfo_file( + synthesizers=synthesizers, + job_args_list=job_args_list, + modality='single_table', + result_writer=result_writer, + ) if job_args_list: scores = _run_jobs(multi_processing_config, job_args_list, show_progress, result_writer) @@ -1373,7 +1447,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client): job_args_key = f'job_args_list_{metainfo}.pkl' job_args_key = f'{path}{job_args_key}' if path else job_args_key - serialized_data = pickle.dumps(job_args_list) + serialized_data = cloudpickle.dumps(job_args_list) s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data) return bucket_name, job_args_key @@ -1384,7 +1458,7 @@ def _get_s3_script_content( ): return f""" import boto3 -import pickle +import cloudpickle from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file from io import StringIO from sdgym.result_writer import S3ResultsWriter @@ -1396,9 +1470,9 @@ def _get_s3_script_content( region_name='{region_name}' ) response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}') -job_args_list = pickle.loads(response['Body'].read()) +job_args_list = cloudpickle.loads(response['Body'].read()) result_writer = S3ResultsWriter(s3_client=s3_client) -_write_metainfo_file({synthesizers}, job_args_list, result_writer) +_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer) scores = _run_jobs(None, job_args_list, False, result_writer=result_writer) metainfo_filename = job_args_list[0][-1]['metainfo'] _update_metainfo_file(metainfo_filename, result_writer) @@ -1509,8 +1583,8 @@ def benchmark_single_table_aws( output_destination, aws_access_key_id=None, aws_secret_access_key=None, - synthesizers=DEFAULT_SYNTHESIZERS, - sdv_datasets=DEFAULT_DATASETS, + synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS, + sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS, additional_datasets_folder=None, limit_dataset_size=False, compute_quality_score=True, @@ -1591,6 +1665,7 @@ def benchmark_single_table_aws( detailed_results_folder=None, custom_synthesizers=None, s3_client=s3_client, + modality='single_table', ) if not job_args_list: return _get_empty_dataframe( @@ -1608,3 +1683,125 @@ def benchmark_single_table_aws( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) + + +def benchmark_multi_table( + synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS, + custom_synthesizers=None, + sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS, + additional_datasets_folder=None, + limit_dataset_size=False, + compute_quality_score=True, + compute_diagnostic_score=True, + timeout=None, + output_destination=None, + show_progress=False, +): + """Run the SDGym benchmark on multi-table datasets. + + Args: + synthesizers (list[string]): + The synthesizer(s) to evaluate. Defaults to ``HMASynthesizer`` and + ``MultiTableUniformSynthesizer``. + custom_synthesizers (list[class] or ``None``): + A list of custom synthesizer classes to use. These can be completely custom or + they can be synthesizer variants (the output from ``create_single_table_synthesizer`` + or ``create_synthesizer_variant``). Defaults to ``None``. + sdv_datasets (list[str] or ``None``): + Names of the SDV demo datasets to use for the benchmark. Defaults to + ``'NBA', 'financial', 'Student_loan', 'Biodegradability', 'fake_hotels', + 'restbase', 'airbnb-simplified'``. 
Use ``None`` to disable using any sdv datasets. + additional_datasets_folder (str or ``None``): + The path to a folder (local or an S3 bucket). Datasets found in this folder are + run in addition to the SDV datasets. If ``None``, no additional datasets are used. + limit_dataset_size (bool): + Limit the dataset to 10 columns per table (not including primary/foreign keys). + The overall dataset is subsampled with referential integrity. + For the main table, select the table with the larges # of rows; and + for num rows, set it to 1000. + compute_quality_score (bool): + Whether or not to evaluate an overall quality score. Defaults to ``True``. + compute_diagnostic_score (bool): + Whether or not to evaluate an overall diagnostic score. Defaults to ``True``. + timeout (int or ``None``): + The maximum number of seconds to wait for synthetic data creation. If ``None``, no + timeout is enforced. + output_destination (str or ``None``): + The path to the output directory where results will be saved. If ``None``, no + output is saved. The results are saved with the following structure: + output_destination/ + / + run_.yaml + SDGym_results_/ + results.csv + _/ + meta.yaml + / + synthesizer.pkl + synthetic_data.csv + + show_progress (bool): + Whether to use tqdm to keep track of the progress. Defaults to ``False``. + + Returns: + pandas.DataFrame: + A table containing one row per synthesizer + dataset + metric. + """ + _validate_output_destination(output_destination) + if not synthesizers: + synthesizers = [] + + _ensure_multi_table_uniform_is_included(synthesizers) + result_writer = LocalResultsWriter() + + _validate_inputs( + output_filepath=None, + detailed_results_folder=None, + synthesizers=synthesizers, + custom_synthesizers=custom_synthesizers, + ) + job_args_list = _generate_job_args_list( + limit_dataset_size=limit_dataset_size, + sdv_datasets=sdv_datasets, + additional_datasets_folder=additional_datasets_folder, + sdmetrics=None, + detailed_results_folder=None, + timeout=timeout, + output_destination=output_destination, + compute_quality_score=compute_quality_score, + compute_diagnostic_score=compute_diagnostic_score, + compute_privacy_score=None, + synthesizers=synthesizers, + custom_synthesizers=custom_synthesizers, + s3_client=None, + modality='multi_table', + ) + + _write_metainfo_file( + synthesizers=synthesizers, + job_args_list=job_args_list, + modality='multi_table', + result_writer=result_writer, + ) + if job_args_list: + scores = _run_jobs( + multi_processing_config=None, + job_args_list=job_args_list, + show_progress=show_progress, + result_writer=result_writer, + ) + + # If no synthesizers/datasets are passed, return an empty dataframe + else: + scores = _get_empty_dataframe( + compute_diagnostic_score=compute_diagnostic_score, + compute_quality_score=compute_quality_score, + compute_privacy_score=None, + sdmetrics=None, + ) + + if output_destination and job_args_list: + metainfo_filename = job_args_list[0][-1]['metainfo'] + _update_metainfo_file(metainfo_filename, result_writer) + + return scores diff --git a/sdgym/dataset_explorer.py b/sdgym/dataset_explorer.py index b6e5da15..cfaad518 100644 --- a/sdgym/dataset_explorer.py +++ b/sdgym/dataset_explorer.py @@ -186,7 +186,7 @@ def _load_and_summarize_datasets(self, modality): Args: modality (str): - The dataset modality to load (e.g., 'single-table' or 'multi-table'). + The dataset modality to load (e.g., 'single_table' or 'multi_table'). 
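For orientation, a minimal invocation of the new entry point might look like the sketch below (dataset and synthesizer names are the PR's defaults; the selected result columns are assumed to match the single-table benchmark output):

```python
from sdgym import benchmark_multi_table

# Run the two default multi-table baselines on one demo dataset,
# trimming each table via limit_dataset_size before fitting.
scores = benchmark_multi_table(
    synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'],
    sdv_datasets=['fake_hotels'],
    limit_dataset_size=True,
    output_destination='benchmark_output',
)
print(scores[['Synthesizer', 'Dataset', 'Quality_Score']])
```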
diff --git a/sdgym/dataset_explorer.py b/sdgym/dataset_explorer.py
index b6e5da15..cfaad518 100644
--- a/sdgym/dataset_explorer.py
+++ b/sdgym/dataset_explorer.py
@@ -186,7 +186,7 @@ def _load_and_summarize_datasets(self, modality):
 
         Args:
             modality (str):
-                The dataset modality to load (e.g., 'single-table' or 'multi-table').
+                The dataset modality to load (e.g., 'single_table' or 'multi_table').
 
         Returns:
             list[dict]:
diff --git a/sdgym/result_explorer/result_explorer.py b/sdgym/result_explorer/result_explorer.py
index c7c18833..d3fdc7b7 100644
--- a/sdgym/result_explorer/result_explorer.py
+++ b/sdgym/result_explorer/result_explorer.py
@@ -2,7 +2,7 @@
 
 import os
 
-from sdgym.benchmark import DEFAULT_DATASETS
+from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS
 from sdgym.datasets import load_dataset
 from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
 from sdgym.s3 import _get_s3_client, is_s3_path
@@ -62,7 +62,7 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_name):
 
     def load_real_data(self, dataset_name):
         """Load the real data for a given dataset."""
-        if dataset_name not in DEFAULT_DATASETS:
+        if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS:
             raise ValueError(
                 f"Dataset '{dataset_name}' is not a SDGym dataset. "
                 'Please provide a valid dataset name.'
diff --git a/sdgym/result_explorer/result_handler.py b/sdgym/result_explorer/result_handler.py
index b0a3ea6f..84f22f3d 100644
--- a/sdgym/result_explorer/result_handler.py
+++ b/sdgym/result_explorer/result_handler.py
@@ -3,10 +3,10 @@
 import io
 import operator
 import os
-import pickle
 from abc import ABC, abstractmethod
 from datetime import datetime
 
+import cloudpickle
 import pandas as pd
 import yaml
 from botocore.exceptions import ClientError
@@ -250,7 +250,7 @@ def get_file_path(self, path_parts, end_filename):
     def load_synthesizer(self, file_path):
         """Load a synthesizer from a pickle file."""
         with open(os.path.join(self.base_path, file_path), 'rb') as f:
-            return pickle.load(f)
+            return cloudpickle.load(f)
 
     def load_synthetic_data(self, file_path):
         """Load synthetic data from a CSV file."""
@@ -361,7 +361,7 @@ def load_synthesizer(self, file_path):
         response = self.s3_client.get_object(
             Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
         )
-        return pickle.loads(response['Body'].read())
+        return cloudpickle.loads(response['Body'].read())
 
     def load_synthetic_data(self, file_path):
         """Load synthetic data from S3."""
diff --git a/sdgym/result_writer.py b/sdgym/result_writer.py
index 3101384a..f4d025d9 100644
--- a/sdgym/result_writer.py
+++ b/sdgym/result_writer.py
@@ -1,10 +1,11 @@
 """Results writer for SDGym benchmark."""
 
 import io
-import pickle
+import zipfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 
+import cloudpickle
 import pandas as pd
 import plotly.graph_objects as go
 import yaml
@@ -35,6 +36,14 @@ def write_yaml(self, data, file_path, append=False):
 class LocalResultsWriter:
     """Local results writer for saving results to the local filesystem."""
 
+    def write_zipped_dataframes(self, data, file_path, index=False):
+        """Write a dictionary of dataframes to a ZIP file."""
+        with zipfile.ZipFile(file_path, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+            for table_name, table in data.items():
+                buf = io.StringIO()
+                table.to_csv(buf, index=index)
+                zf.writestr(f'{table_name}.csv', buf.getvalue())
+
     def write_dataframe(self, data, file_path, append=False, index=False):
         """Write a DataFrame to a CSV file."""
         file_path = Path(file_path)
@@ -82,7 +91,7 @@ def write_xlsx(self, data, file_path, index=False):
     def write_pickle(self, obj, file_path):
         """Write a Python object to a pickle file."""
         with open(file_path, 'wb') as f:
-            pickle.dump(obj, f)
+            cloudpickle.dump(obj, f)
 
     def write_yaml(self, data, file_path, append=False):
         """Write data to a YAML file."""
@@ -126,7 +135,7 @@ def write_pickle(self, obj, file_path):
         """Write a Python object to S3 as a pickle file."""
         bucket, key = parse_s3_path(file_path)
         buffer = io.BytesIO()
-        pickle.dump(obj, buffer)
+        cloudpickle.dump(obj, buffer)
         buffer.seek(0)
         self.s3_client.put_object(Body=buffer.read(), Bucket=bucket, Key=key)
""" return self._sample_from_synthesizer(synthesizer, n_samples) + + +class MultiTableBaselineSynthesizer(BaselineSynthesizer): + """Base class for all multi-table synthesizers.""" + + _MODALITY_FLAG = 'multi_table' + + def sample_from_synthesizer(self, synthesizer, scale=1.0): + """Sample data from the provided synthesizer. + + Args: + synthesizer (obj): + The synthesizer object to sample data from. + scale (float): + The scale of data to sample. Defaults to 1.0. + + Returns: + dict: + The sampled data. A dict mapping table name to DataFrame. + """ + return self._sample_from_synthesizer(synthesizer, scale) diff --git a/sdgym/synthesizers/column.py b/sdgym/synthesizers/column.py index 69233283..94107f69 100644 --- a/sdgym/synthesizers/column.py +++ b/sdgym/synthesizers/column.py @@ -19,9 +19,11 @@ class ColumnSynthesizer(BaselineSynthesizer): Continuous columns are learned and sampled using a GMM. """ - def _get_trained_synthesizer(self, real_data, metadata): + _MODALITY_FLAG = 'single_table' + + def _fit(self, data, metadata): hyper_transformer = HyperTransformer() - hyper_transformer.detect_initial_config(real_data) + hyper_transformer.detect_initial_config(data) supported_sdtypes = hyper_transformer._get_supported_sdtypes() config = {} if isinstance(metadata, Metadata): @@ -46,14 +48,14 @@ def _get_trained_synthesizer(self, real_data, metadata): # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [ - column_name for column_name, data in real_data.items() if data.dtype.kind in {'O', 'i'} + column_name for column_name, data in data.items() if data.dtype.kind in {'O', 'i'} ] hyper_transformer.remove_transformers(columns_to_remove) - hyper_transformer.fit(real_data) - transformed = hyper_transformer.transform(real_data) + hyper_transformer.fit(data) + transformed = hyper_transformer.transform(data) - self.length = len(real_data) + self.length = len(data) gm_models = {} for name, column in transformed.items(): kind = column.dtype.kind @@ -63,18 +65,22 @@ def _get_trained_synthesizer(self, real_data, metadata): model.fit(column.to_numpy().reshape(-1, 1)) gm_models[name] = model - return (hyper_transformer, transformed, gm_models) + self.hyper_transformer = hyper_transformer + self.transformed_data = transformed + self.gm_models = gm_models def _sample_from_synthesizer(self, synthesizer, n_samples): - hyper_transformer, transformed, gm_models = synthesizer + hyper_transformer = synthesizer.hyper_transformer + transformed = synthesizer.transformed_data + gm_models = synthesizer.gm_models sampled = pd.DataFrame() for name, column in transformed.items(): kind = column.dtype.kind if kind == 'O': - values = column.sample(self.length, replace=True).to_numpy() + values = column.sample(n_samples, replace=True).to_numpy() else: model = gm_models.get(name) - values = model.sample(self.length)[0].ravel().clip(column.min(), column.max()) + values = model.sample(n_samples)[0].ravel().clip(column.min(), column.max()) sampled[name] = values diff --git a/sdgym/synthesizers/generate.py b/sdgym/synthesizers/generate.py index 5535dd01..d341d03b 100644 --- a/sdgym/synthesizers/generate.py +++ b/sdgym/synthesizers/generate.py @@ -1,6 +1,10 @@ """Helpers to create SDGym synthesizer variants.""" -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _validate_modality, +) from sdgym.synthesizers.utils import _get_supported_synthesizers @@ -36,7 +40,7 @@ def 
create_synthesizer_variant(display_name, synthesizer_class, synthesizer_para return NewSynthesizer -def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_arg_name): +def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, modality): """Create a synthesizer class. Args: @@ -47,36 +51,39 @@ def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_ar A function to generate and train a synthesizer, given the real data and metadata. sample_from_synthesizer (callable): A function to sample from the given synthesizer. - sample_arg_name (str): - The name of the argument used to specify the number of samples to generate. - Either 'num_samples' for single-table synthesizers, or 'scale' for multi-table - synthesizers. + modality (str): + The modality of the synthesizer. Either 'single_table' or 'multi_table'. Returns: class: The synthesizer class. """ + _validate_modality(modality) class_name = f'Custom:{display_name}' def get_trained_synthesizer(self, data, metadata): return get_trained_fn(data, metadata) - if sample_arg_name == 'num_samples': + if modality == 'multi_table': - def sample_from_synthesizer(self, synthesizer, num_samples): - return sample_fn(synthesizer, num_samples) + def sample_from_synthesizer(self, synthesizer, scale=1.0): + return sample_fn(synthesizer, scale) + base_class = MultiTableBaselineSynthesizer else: - def sample_from_synthesizer(self, synthesizer, scale): - return sample_fn(synthesizer, scale) + def sample_from_synthesizer(self, synthesizer, n_samples): + return sample_fn(synthesizer, n_samples) + + base_class = BaselineSynthesizer CustomSynthesizer = type( class_name, - (BaselineSynthesizer,), + (base_class,), { '__module__': __name__, '_NATIVELY_SUPPORTED': False, + '_MODALITY_FLAG': modality, 'get_trained_synthesizer': get_trained_synthesizer, 'sample_from_synthesizer': sample_from_synthesizer, }, @@ -94,7 +101,7 @@ def create_single_table_synthesizer( display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn, - sample_arg_name='num_samples', + modality='single_table', ) @@ -106,5 +113,5 @@ def create_multi_table_synthesizer( display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn, - sample_arg_name='scale', + modality='multi_table', ) diff --git a/sdgym/synthesizers/identity.py b/sdgym/synthesizers/identity.py index d63b5956..1c41031a 100644 --- a/sdgym/synthesizers/identity.py +++ b/sdgym/synthesizers/identity.py @@ -11,24 +11,21 @@ class DataIdentity(BaselineSynthesizer): Returns the same exact data that is used to fit it. """ + _MODALITY_FLAG = 'single_table' + def __init__(self): self._data = None - def get_trained_synthesizer(self, data, metadata): - """Get a synthesizer that has been trained on the provided data and metadata. + def _fit(self, data, metadata): + """Fit the synthesizer to the data. Args: data (pandas.DataFrame): - The data to train on. + The data to fit the synthesizer to. metadata (dict): The metadata dictionary. - - Returns: - obj: - The synthesizer object. """ self._data = data - return None def sample_from_synthesizer(self, synthesizer, n_samples): """Sample data from the provided synthesizer. @@ -44,4 +41,4 @@ def sample_from_synthesizer(self, synthesizer, n_samples): The sampled data. If single-table, should be a DataFrame. If multi-table, should be a dict mapping table name to DataFrame. 
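With the modality switch above, declaring a custom multi-table variant takes two plain functions, and the sample function now receives a `scale` rather than a row count. A sketch with illustrative function bodies (nothing here is prescribed by the PR beyond the `create_multi_table_synthesizer` signature):

```python
from sdgym import create_multi_table_synthesizer


def get_trained_synthesizer(data, metadata):
    # Train on a dict of table-name -> DataFrame; return any object
    # that sample_from_synthesizer knows how to use.
    return {table: df.copy() for table, df in data.items()}


def sample_from_synthesizer(synthesizer, scale):
    # Scale each stored table by sampling rows with replacement.
    return {
        table: df.sample(frac=scale, replace=True)
        for table, df in synthesizer.items()
    }


MyMultiTable = create_multi_table_synthesizer(
    'MyMultiTable', get_trained_synthesizer, sample_from_synthesizer
)
```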
""" - return copy.deepcopy(self._data) + return copy.deepcopy(synthesizer._data) diff --git a/sdgym/synthesizers/realtabformer.py b/sdgym/synthesizers/realtabformer.py index 8b46bda5..92d8dc68 100644 --- a/sdgym/synthesizers/realtabformer.py +++ b/sdgym/synthesizers/realtabformer.py @@ -24,8 +24,10 @@ class RealTabFormerSynthesizer(BaselineSynthesizer): LOGGER = logging.getLogger(__name__) _MODEL_KWARGS = None + _MODALITY_FLAG = 'single_table' - def _get_trained_synthesizer(self, data, metadata): + def _fit(self, data, metadata): + """Fit the REaLTabFormer model to the data.""" try: from realtabformer import REaLTabFormer except Exception as exception: @@ -39,8 +41,8 @@ def _get_trained_synthesizer(self, data, metadata): model = REaLTabFormer(model_type='tabular', **model_kwargs) model.fit(data) - return model + self._internal_synthesizer = model def _sample_from_synthesizer(self, synthesizer, n_sample): """Sample synthetic data with specified sample count.""" - return synthesizer.sample(n_sample) + return synthesizer._internal_synthesizer.sample(n_sample) diff --git a/sdgym/synthesizers/sdv.py b/sdgym/synthesizers/sdv.py index 90e4e3a6..d6f90b1f 100644 --- a/sdgym/synthesizers/sdv.py +++ b/sdgym/synthesizers/sdv.py @@ -6,7 +6,11 @@ from sdv import multi_table, single_table -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _validate_modality, +) LOGGER = logging.getLogger(__name__) UNSUPPORTED_SDV_SYNTHESIZERS = ['DayZSynthesizer'] @@ -15,11 +19,7 @@ 'multi_table': multi_table, } - -def _validate_modality(modality): - """Validate that the modality is correct.""" - if modality not in ['single_table', 'multi_table']: - raise ValueError("`modality` must be one of 'single_table' or 'multi_table'.") +MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}} def _get_sdv_synthesizers(modality): @@ -39,20 +39,20 @@ def _get_all_sdv_synthesizers(): return sorted(synthesizers) -def _get_trained_synthesizer(self, data, metadata): +def _fit(self, data, metadata): LOGGER.info('Fitting %s', self.__class__.__name__) - sdv_class = getattr(import_module(f'sdv.{self.modality}'), self.SDV_NAME) + sdv_class = getattr(import_module(f'sdv.{self._MODALITY_FLAG}'), self.SDV_NAME) synthesizer = sdv_class(metadata=metadata, **self._MODEL_KWARGS) synthesizer.fit(data) - return synthesizer + self._internal_synthesizer = synthesizer def _sample_from_synthesizer(self, synthesizer, sample_arg): LOGGER.info('Sampling %s', self.__class__.__name__) - if self.modality == 'multi_table': - return synthesizer.sample(scale=sample_arg) + if self._MODALITY_FLAG == 'multi_table': + return synthesizer._internal_synthesizer.sample(scale=sample_arg) - return synthesizer.sample(num_rows=sample_arg) + return synthesizer._internal_synthesizer.sample(num_rows=sample_arg) def _retrieve_sdv_class(sdv_name): @@ -82,15 +82,17 @@ def _create_sdv_class(sdv_name): """Create a SDV synthesizer class dynamically.""" current_module = sys.modules[__name__] modality = _get_modality(sdv_name) + base_class = MultiTableBaselineSynthesizer if modality == 'multi_table' else BaselineSynthesizer + model_kwargs = MODEL_KWARGS.get(sdv_name, {}) synthesizer_class = type( sdv_name, - (BaselineSynthesizer,), + (base_class,), { '__module__': __name__, 'SDV_NAME': sdv_name, - 'modality': modality, - '_MODEL_KWARGS': {}, - '_get_trained_synthesizer': _get_trained_synthesizer, + '_MODALITY_FLAG': modality, + '_MODEL_KWARGS': model_kwargs, + '_fit': _fit, 
'_sample_from_synthesizer': _sample_from_synthesizer, }, ) diff --git a/sdgym/synthesizers/uniform.py b/sdgym/synthesizers/uniform.py index 57713839..f562c1bf 100644 --- a/sdgym/synthesizers/uniform.py +++ b/sdgym/synthesizers/uniform.py @@ -7,7 +7,7 @@ import pandas as pd from rdt.hyper_transformer import HyperTransformer -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import BaselineSynthesizer, MultiTableBaselineSynthesizer LOGGER = logging.getLogger(__name__) @@ -15,9 +15,24 @@ class UniformSynthesizer(BaselineSynthesizer): """Synthesizer that samples each column using a Uniform distribution.""" - def _get_trained_synthesizer(self, real_data, metadata): + _MODALITY_FLAG = 'single_table' + + def __init__(self): + super().__init__() + self.hyper_transformer = None + self.transformed_data = None + + def _fit(self, data, metadata): + """Fit the synthesizer to the data. + + Args: + data (pd.DataFrame): + The data to fit the synthesizer to. + metadata (sdv.metadata.Metadata): + The metadata describing the data. + """ hyper_transformer = HyperTransformer() - hyper_transformer.detect_initial_config(real_data) + hyper_transformer.detect_initial_config(data) supported_sdtypes = hyper_transformer._get_supported_sdtypes() config = {} table = next(iter(metadata.tables.values())) @@ -44,29 +59,82 @@ def _get_trained_synthesizer(self, real_data, metadata): # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [ column_name - for column_name, data in real_data.items() - if data.dtype.kind in {'O', 'i', 'b'} + for column_name, column_data in data.items() + if column_data.dtype.kind in {'O', 'i', 'b'} ] hyper_transformer.remove_transformers(columns_to_remove) - hyper_transformer.fit(real_data) - transformed = hyper_transformer.transform(real_data) - - self.length = len(real_data) - return (hyper_transformer, transformed) + hyper_transformer.fit(data) + transformed = hyper_transformer.transform(data) + self.hyper_transformer = hyper_transformer + self.transformed_data = transformed def _sample_from_synthesizer(self, synthesizer, n_samples): - hyper_transformer, transformed = synthesizer + hyper_transformer = synthesizer.hyper_transformer + transformed = synthesizer.transformed_data sampled = pd.DataFrame() for name, column in transformed.items(): kind = column.dtype.kind if kind == 'i': - values = np.random.randint(column.min(), column.max() + 1, size=self.length) + values = np.random.randint( + int(column.min()), int(column.max()) + 1, size=n_samples, dtype=np.int64 + ) elif kind in ['O', 'b']: - values = np.random.choice(column.unique(), size=self.length) + values = np.random.choice(column.unique(), size=n_samples) else: - values = np.random.uniform(column.min(), column.max(), size=self.length) - + values = np.random.uniform(column.min(), column.max(), size=n_samples) sampled[name] = values return hyper_transformer.reverse_transform(sampled) + + +class MultiTableUniformSynthesizer(MultiTableBaselineSynthesizer): + """Multi-table Uniform Synthesizer. + + This synthesizer trains a UniformSynthesizer on each table in the multi-table dataset. + It samples data from each table independently using the corresponding trained synthesizer. + """ + + def __init__(self): + super().__init__() + self.num_rows_per_table = {} + self.table_synthesizers = {} + + def _fit(self, data, metadata): + """Fit the synthesizer to the multi-table data. + + Args: + data (dict): + A dict mapping table name to table data. 
+ metadata (sdv.metadata.MultiTableMetadata): + The multi-table metadata describing the data. + """ + for table_name, table_data in data.items(): + table_metadata = metadata.get_table_metadata(table_name) + synthesizer = UniformSynthesizer() + synthesizer._fit(table_data, table_metadata) + self.num_rows_per_table[table_name] = len(table_data) + self.table_synthesizers[table_name] = synthesizer + + def _sample_from_synthesizer(self, synthesizer, scale): + """Sample data from the provided synthesizer. + + Args: + synthesizer (SDGym synthesizer): + The synthesizer object to sample data from. + scale (float): + The scale of data to sample. + Defaults to 1.0. + + Returns: + dict: A dict mapping table name to the sampled data. + """ + sampled_data = {} + for table_name, table_synthesizer in synthesizer.table_synthesizers.items(): + n_samples = int(synthesizer.num_rows_per_table[table_name] * scale) + sampled_table = UniformSynthesizer().sample_from_synthesizer( + table_synthesizer, n_samples=n_samples + ) + sampled_data[table_name] = sampled_table + + return sampled_data diff --git a/sdgym/synthesizers/utils.py b/sdgym/synthesizers/utils.py index 18cdd9f0..c30f752a 100644 --- a/sdgym/synthesizers/utils.py +++ b/sdgym/synthesizers/utils.py @@ -1,22 +1,6 @@ """Utility functions for synthesizers in SDGym.""" from sdgym.synthesizers.base import BaselineSynthesizer -from sdgym.synthesizers.sdv import _get_all_sdv_synthesizers, _get_sdv_synthesizers - - -def _get_sdgym_synthesizers(): - """Get SDGym synthesizers. - - Returns: - list: - A list of available SDGym synthesizer names. - """ - synthesizers = BaselineSynthesizer._get_supported_synthesizers() - sdv_synthesizer = _get_all_sdv_synthesizers() - sdgym_synthesizer = [ - synthesizer for synthesizer in synthesizers if synthesizer not in sdv_synthesizer - ] - return sorted(sdgym_synthesizer) def get_available_single_table_synthesizers(): @@ -26,9 +10,7 @@ def get_available_single_table_synthesizers(): list: A sorted list of available single-table synthesizer names. """ - sdv_synthesizers = _get_sdv_synthesizers('single_table') - sdgym_synthesizers = _get_sdgym_synthesizers() - return sorted(sdv_synthesizers + sdgym_synthesizers) + return sorted(BaselineSynthesizer._get_supported_synthesizers('single_table')) def get_available_multi_table_synthesizers(): @@ -38,7 +20,7 @@ def get_available_multi_table_synthesizers(): list: A sorted list of available multi-table synthesizer names. """ - return sorted(_get_sdv_synthesizers('multi_table')) + return sorted(BaselineSynthesizer._get_supported_synthesizers('multi_table')) def _get_supported_synthesizers(): @@ -48,4 +30,8 @@ def _get_supported_synthesizers(): list: A list of available SDGym supported synthesizer names. 
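Note how the scale semantics above reduce to a per-table row count recorded at fit time; a quick arithmetic sketch with made-up table sizes:

```python
# Each table is sampled independently: scale is applied to the row
# count that was recorded for that table during _fit.
num_rows_per_table = {'hotels': 10, 'guests': 658}
scale = 0.5

expected = {name: int(rows * scale) for name, rows in num_rows_per_table.items()}
assert expected == {'hotels': 5, 'guests': 329}
```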
""" - return BaselineSynthesizer._get_supported_synthesizers() + synthesizers = [] + for modality in ['single_table', 'multi_table']: + synthesizers.extend(BaselineSynthesizer._get_supported_synthesizers(modality)) + + return sorted(synthesizers) diff --git a/sdgym/utils.py b/sdgym/utils.py index b6ff1b47..2ff2f9f4 100644 --- a/sdgym/utils.py +++ b/sdgym/utils.py @@ -74,6 +74,7 @@ def get_synthesizers(synthesizers): synthesizer_name = getattr(synthesizer, '__name__', 'undefined') else: synthesizer_name = getattr(type(synthesizer), '__name__', 'undefined') + synthesizers_dicts.append({ 'name': synthesizer_name, 'synthesizer': synthesizer, diff --git a/tests/integration/result_explorer/test_result_explorer.py b/tests/integration/result_explorer/test_result_explorer.py index 188053fd..4a9e3493 100644 --- a/tests/integration/result_explorer/test_result_explorer.py +++ b/tests/integration/result_explorer/test_result_explorer.py @@ -10,16 +10,17 @@ def test_end_to_end_local(tmp_path): """Test the ResultsExplorer end-to-end with local paths.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') + output_destination = tmp_path / 'benchmark_output' + result_explorer_path = output_destination / 'single_table' benchmark_single_table( - output_destination=output_destination, + output_destination=str(output_destination), synthesizers=['GaussianCopulaSynthesizer', 'TVAESynthesizer'], sdv_datasets=['expedia_hotel_logs', 'fake_companies'], ) today = time.strftime('%m_%d_%Y') # Run - result_explorer = ResultsExplorer(output_destination) + result_explorer = ResultsExplorer(str(result_explorer_path)) runs = result_explorer.list() results = result_explorer.load_results(runs[0]) metainfo = result_explorer.load_metainfo(runs[0]) @@ -38,10 +39,11 @@ def test_end_to_end_local(tmp_path): dataset_name='fake_companies', synthesizer_name='TVAESynthesizer', ) + assert isinstance(synthesizer, TVAESynthesizer) new_synthetic_data = synthesizer.sample(num_rows=10) # Assert - expected_results = pd.read_csv(f'{output_destination}/SDGym_results_{today}/results.csv') + expected_results = pd.read_csv(f'{result_explorer_path}/SDGym_results_{today}/results.csv') pd.testing.assert_frame_equal(results, expected_results) assert metainfo[f'run_{today}_0']['jobs'] == [ ['expedia_hotel_logs', 'GaussianCopulaSynthesizer'], @@ -62,7 +64,7 @@ def test_end_to_end_local(tmp_path): def test_summarize(): """Test the `summarize` method.""" # Setup - output_destination = 'tests/integration/result_explorer/_benchmark_results' + output_destination = 'tests/integration/result_explorer/_benchmark_results/' result_explorer = ResultsExplorer(output_destination) # Run diff --git a/tests/integration/synthesizers/test_column.py b/tests/integration/synthesizers/test_column.py index e22c1196..45c29eba 100644 --- a/tests/integration/synthesizers/test_column.py +++ b/tests/integration/synthesizers/test_column.py @@ -29,7 +29,9 @@ def test_column_synthesizer(self): # Run trained_synthesizer = column_synthesizer.get_trained_synthesizer(data, {}) - samples = column_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) + samples = column_synthesizer.sample_from_synthesizer( + trained_synthesizer, n_samples=n_samples + ) # Assert assert samples['num'].between(-10, 10).all() @@ -105,7 +107,7 @@ def test_column_synthesizer_sdtypes(self): # Run real_data = pd.DataFrame(data) synthesizer = ColumnSynthesizer().get_trained_synthesizer(real_data, metadata) - hyper_transformer_config = synthesizer[0].get_config() + hyper_transformer_config 
= synthesizer.hyper_transformer.get_config() # Assert config_sdtypes = hyper_transformer_config['sdtypes'] diff --git a/tests/integration/synthesizers/test_uniform.py b/tests/integration/synthesizers/test_uniform.py index e807d375..366409a2 100644 --- a/tests/integration/synthesizers/test_uniform.py +++ b/tests/integration/synthesizers/test_uniform.py @@ -2,8 +2,10 @@ import numpy as np import pandas as pd +from pandas.api.types import is_numeric_dtype +from sdv.datasets.demo import download_demo -from sdgym.synthesizers.uniform import UniformSynthesizer +from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer def test_uniform_synthesizer(): @@ -30,7 +32,7 @@ def test_uniform_synthesizer(): # Run trained_synthesizer = uniform_synthesizer.get_trained_synthesizer(data, {}) - samples = uniform_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) + samples = uniform_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples=n_samples) # Assert numerical values are uniform min_val = samples['num'].min() @@ -69,3 +71,25 @@ def test_uniform_synthesizer(): assert n_values_interval2 * 0.9 < n_values_interval1 < n_values_interval2 * 1.1 assert n_values_interval3 * 0.9 < n_values_interval1 < n_values_interval3 * 1.1 + + +def test_multitable_uniform_synthesizer_end_to_end(): + """Test the MultiTableUniformSynthesizer end to end.""" + # Setup + data, metadata = download_demo(dataset_name='fake_hotels', modality='multi_table') + synthesizer = MultiTableUniformSynthesizer() + + # Run + trained_synthesizer = synthesizer.get_trained_synthesizer(data, metadata.to_dict()) + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=2) + + # Assert + for table_name, table_data in data.items(): + sampled_table = sampled_data[table_name] + assert len(sampled_table) == len(table_data) * 2 + for column_name in table_data.columns: + original_column = table_data[column_name] + sampled_column = sampled_table[column_name] + if is_numeric_dtype(original_column): + assert sampled_column.min() >= original_column.min() + assert sampled_column.max() <= original_column.max() diff --git a/tests/integration/synthesizers/test_utils.py b/tests/integration/synthesizers/test_utils.py index 60232de0..ad78c519 100644 --- a/tests/integration/synthesizers/test_utils.py +++ b/tests/integration/synthesizers/test_utils.py @@ -28,7 +28,7 @@ def test_get_available_single_table_synthesizers(): def test_get_available_multi_table_synthesizers(): """Test the `get_available_multi_table_synthesizers` method""" # Setup - expected_synthesizers = ['HMASynthesizer'] + expected_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer'] # Run synthesizers = get_available_multi_table_synthesizers() diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index b1b884aa..46f71466 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -17,6 +17,7 @@ import sdgym from sdgym import ( + benchmark_multi_table, benchmark_single_table, create_single_table_synthesizer, create_synthesizer_variant, @@ -629,73 +630,82 @@ def sample_from_synthesizer(synthesizer, n_samples): def test_benchmark_single_table_no_warnings(): """Test that the benchmark does not raise any FutureWarnings.""" # Run - with warnings.catch_warnings(record=True) as w: + with warnings.catch_warnings(record=True) as catched_warnings: benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer'], sdv_datasets=['fake_companies'] ) - future_warnings = 
[warning for warning in w if issubclass(warning.category, FutureWarning)] - assert len(future_warnings) == 0 + + # Assert + future_warnings = [ + warning for warning in catched_warnings if issubclass(warning.category, FutureWarning) + ] + assert len(future_warnings) == 0 def test_benchmark_single_table_with_output_destination(tmp_path): """Test it works with the ``output_destination`` argument.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') + output_destination = tmp_path / 'benchmark_output' today_date = pd.Timestamp.now().strftime('%m_%d_%Y') # Run results = benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer', 'TVAESynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), # function may require str ) # Assert - directions = os.listdir(output_destination) - score_saved_separately = pd.DataFrame() - assert directions == [f'SDGym_results_{today_date}'] - subdirections = os.listdir(os.path.join(output_destination, directions[0])) - assert set(subdirections) == { + top_level = os.listdir(output_destination) + assert top_level == ['single_table'] + + second_level = os.listdir(output_destination / 'single_table') + assert second_level == [f'SDGym_results_{today_date}'] + + subdir = output_destination / 'single_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { 'results.csv', f'fake_companies_{today_date}', 'metainfo.yaml', } - with open(os.path.join(output_destination, directions[0], 'metainfo.yaml'), 'r') as f: + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: metadata = yaml.safe_load(f) - assert metadata['completed_date'] is not None - assert metadata['sdgym_version'] == sdgym.__version__ - synthesizer_directions = os.listdir( - os.path.join(output_destination, directions[0], f'fake_companies_{today_date}') - ) - assert set(synthesizer_directions) == { + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + + # Synthesizer directories + synth_dir = subdir / f'fake_companies_{today_date}' + synthesizer_dirs = os.listdir(synth_dir) + assert set(synthesizer_dirs) == { 'TVAESynthesizer', 'GaussianCopulaSynthesizer', 'UniformSynthesizer', } - for synthesizer in sorted(synthesizer_directions): - synthesizer_files = os.listdir( - os.path.join( - output_destination, directions[0], f'fake_companies_{today_date}', synthesizer - ) - ) - assert set(synthesizer_files) == { + + # Validate files in each synthesizer directory + score_saved_separately = pd.DataFrame() + for synthesizer in sorted(synthesizer_dirs): + files = os.listdir(synth_dir / synthesizer) + assert set(files) == { f'{synthesizer}.pkl', f'{synthesizer}_synthetic_data.csv', f'{synthesizer}_benchmark_result.csv', } - score = pd.read_csv( - os.path.join( - output_destination, - directions[0], - f'fake_companies_{today_date}', - synthesizer, - f'{synthesizer}_benchmark_result.csv', - ) - ) + + score_path = synth_dir / synthesizer / f'{synthesizer}_benchmark_result.csv' + score = pd.read_csv(score_path) score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) - saved_result = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results.csv') + # Load top-level results.csv + saved_results_path = ( + output_destination / 'single_table' / f'SDGym_results_{today_date}' / 'results.csv' + ) + saved_result = pd.read_csv(saved_results_path) + + # Assert Results pd.testing.assert_frame_equal(results, 
saved_result, check_dtype=False) results_no_adjusted = results.drop(columns=['Adjusted_Total_Time', 'Adjusted_Quality_Score']) pd.testing.assert_frame_equal(results_no_adjusted, score_saved_separately, check_dtype=False) @@ -704,76 +714,218 @@ def test_benchmark_single_table_with_output_destination(tmp_path): def test_benchmark_single_table_with_output_destination_multiple_runs(tmp_path): """Test saving in ``output_destination`` with multiple runs. - Here two benchmark runs are performed with different synthesizers - on the same dataset, and the results are saved in the same output directory. - The directory contains a `results.csv` file with the combined results - and a subdirectory for each synthesizer with its own results. + Two benchmark runs are performed with different synthesizers on the same + dataset, saving results to the same output directory. The directory contains + multiple `results.csv` files and synthesizer subdirectories. """ # Setup - output_destination = str(tmp_path / 'benchmark_output') + output_destination = tmp_path / 'benchmark_output' today_date = pd.Timestamp.now().strftime('%m_%d_%Y') # Run result_1 = benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), ) result_2 = benchmark_single_table( synthesizers=['TVAESynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), ) # Assert score_saved_separately = pd.DataFrame() - directions = os.listdir(output_destination) - assert directions == [f'SDGym_results_{today_date}'] - subdirections = os.listdir(os.path.join(output_destination, directions[0])) - assert set(subdirections) == { + + top_level = os.listdir(output_destination) + assert top_level == ['single_table'] + + second_level = os.listdir(output_destination / 'single_table') + assert second_level == [f'SDGym_results_{today_date}'] + + subdir = output_destination / 'single_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { 'results.csv', 'results(1).csv', f'fake_companies_{today_date}', 'metainfo.yaml', 'metainfo(1).yaml', } - with open(os.path.join(output_destination, directions[0], 'metainfo.yaml'), 'r') as f: + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: metadata = yaml.safe_load(f) - assert metadata['completed_date'] is not None - assert metadata['sdgym_version'] == sdgym.__version__ - synthesizer_directions = os.listdir( - os.path.join(output_destination, directions[0], f'fake_companies_{today_date}') - ) - assert set(synthesizer_directions) == { + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + + # Synthesizer directories + synth_parent = subdir / f'fake_companies_{today_date}' + synthesizer_dirs = os.listdir(synth_parent) + + # Assert Synthesizer directories + assert set(synthesizer_dirs) == { 'TVAESynthesizer(1)', 'GaussianCopulaSynthesizer', 'UniformSynthesizer', 'UniformSynthesizer(1)', } - for synthesizer in sorted(synthesizer_directions): - synthesizer_files = os.listdir( - os.path.join( - output_destination, directions[0], f'fake_companies_{today_date}', synthesizer - ) - ) - assert set(synthesizer_files) == { + + # Validate each synthesizer directory + for synthesizer in sorted(synthesizer_dirs): + synth_path = synth_parent / synthesizer + + synth_files = os.listdir(synth_path) + assert set(synth_files) == { f'{synthesizer}.pkl', 
f'{synthesizer}_synthetic_data.csv', f'{synthesizer}_benchmark_result.csv', } - score = pd.read_csv( - os.path.join( - output_destination, - directions[0], - f'fake_companies_{today_date}', - synthesizer, - f'{synthesizer}_benchmark_result.csv', - ) - ) + + score = pd.read_csv(synth_path / f'{synthesizer}_benchmark_result.csv') + score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) + + # Load saved results + saved_result_1 = pd.read_csv(subdir / 'results.csv') + saved_result_2 = pd.read_csv(subdir / 'results(1).csv') + + # Assert results + pd.testing.assert_frame_equal(result_1, saved_result_1, check_dtype=False) + pd.testing.assert_frame_equal(result_2, saved_result_2, check_dtype=False) + + +def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path): + """Test saving in ``output_destination`` with multiple runs in multi-table mode. + + Two benchmark runs are performed with HMASynthesizer on the same multi-table + dataset, saving results to the same output directory. The directory contains + multiple `results*.csv` files, metainfo files, and synthesizer subdirectories. + """ + # Setup + output_destination = tmp_path / 'benchmark_output' + today_date = pd.Timestamp.now().strftime('%m_%d_%Y') + + # Run 1 + result_1 = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + sdv_datasets=['fake_hotels'], + output_destination=str(output_destination), + ) + + # Run 2 + result_2 = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + sdv_datasets=['fake_hotels'], + output_destination=str(output_destination), + ) + + # Assert + score_saved_separately = pd.DataFrame() + + # Top level directory + top_level = os.listdir(output_destination) + assert top_level == ['multi_table'] + + # Second level + second_level = os.listdir(output_destination / 'multi_table') + assert second_level == [f'SDGym_results_{today_date}'] + + # SDGym results folder + subdir = output_destination / 'multi_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { + 'results.csv', + 'results(1).csv', + f'fake_hotels_{today_date}', + 'metainfo.yaml', + 'metainfo(1).yaml', + } + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: + metadata = yaml.safe_load(f) + + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + assert metadata['modality'] == 'multi_table' + + # Synthesizer folders + synth_parent = subdir / f'fake_hotels_{today_date}' + synthesizer_dirs = os.listdir(synth_parent) + + assert set(synthesizer_dirs) == { + 'HMASynthesizer', + 'HMASynthesizer(1)', + 'MultiTableUniformSynthesizer', + 'MultiTableUniformSynthesizer(1)', + } + + # Validate each synthesizer directory + for synthesizer in sorted(synthesizer_dirs): + synth_path = synth_parent / synthesizer + + synth_files = os.listdir(synth_path) + assert set(synth_files) == { + f'{synthesizer}.pkl', + f'{synthesizer}_synthetic_data.zip', + f'{synthesizer}_benchmark_result.csv', + } + + score = pd.read_csv(synth_path / f'{synthesizer}_benchmark_result.csv') score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) - saved_result_1 = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results.csv') - saved_result_2 = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results(1).csv') + # Load results for both runs + saved_result_1 = pd.read_csv(subdir / 'results.csv') + saved_result_2 = pd.read_csv(subdir / 'results(1).csv') + + # Validate the stored results match returned 
results pd.testing.assert_frame_equal(result_1, saved_result_1, check_dtype=False) pd.testing.assert_frame_equal(result_2, saved_result_2, check_dtype=False) + + +def test_benchmark_multi_table_basic_synthesizers(): + """Integration test: run HMASynthesizer + MultiTableUniformSynthesizer on fake_hotels.""" + output = benchmark_multi_table( + synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'], + sdv_datasets=['fake_hotels'], + compute_quality_score=True, + compute_diagnostic_score=True, + limit_dataset_size=True, + show_progress=False, + timeout=30, + ) + + # Assert + assert isinstance(output, pd.DataFrame) + assert not output.empty + + # Required SDGym benchmark output columns + for col in [ + 'Synthesizer', + 'Train_Time', + 'Sample_Time', + 'Quality_Score', + 'Diagnostic_Score', + ]: + assert col in output.columns + + synths = sorted(output['Synthesizer'].unique()) + assert synths == [ + 'HMASynthesizer', + 'MultiTableUniformSynthesizer', + ] + + diagnostic_rank = ( + output.groupby('Synthesizer').Diagnostic_Score.mean().sort_values().index.tolist() + ) + + assert diagnostic_rank == [ + 'MultiTableUniformSynthesizer', + 'HMASynthesizer', + ] + + quality_rank = output.groupby('Synthesizer').Quality_Score.mean().sort_values().index.tolist() + + assert quality_rank == [ + 'MultiTableUniformSynthesizer', + 'HMASynthesizer', + ] diff --git a/tests/unit/synthesizers/test_base.py b/tests/unit/synthesizers/test_base.py index 73f1be2b..e1f620d4 100644 --- a/tests/unit/synthesizers/test_base.py +++ b/tests/unit/synthesizers/test_base.py @@ -1,23 +1,61 @@ +import re import warnings from unittest.mock import Mock, patch import pandas as pd +import pytest from sdv.metadata import Metadata -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _is_valid_modality, + _validate_modality, +) + + +@pytest.mark.parametrize( + 'modality, result', + [ + ('single_table', True), + ('multi_table', True), + ('invalid_modality', False), + ], +) +def test__is_valid_modality(modality, result): + """Test the `_is_valid_modality` method.""" + assert _is_valid_modality(modality) == result + + +def test__validate_modality(): + """Test the `_validate_modality` method.""" + # Setup + valid_modality = 'single_table' + invalid_modality = 'invalid_modality' + expected_error = re.escape( + f"Modality '{invalid_modality}' is not valid. Must be either " + "'single_table' or 'multi_table'." 
+ ) + + # Run and Assert + _validate_modality(valid_modality) + with pytest.raises(ValueError, match=expected_error): + _validate_modality(invalid_modality) class TestBaselineSynthesizer: - @patch('sdgym.synthesizers.utils.BaselineSynthesizer.get_subclasses') - def test__get_supported_synthesizers_mock(self, mock_get_subclasses): + @patch('sdgym.synthesizers.base.BaselineSynthesizer.get_subclasses') + @patch('sdgym.synthesizers.base._validate_modality') + def test__get_supported_synthesizers_mock(self, mock_validate_modality, mock_get_subclasses): """Test the `_get_supported_synthesizers` method with mocks.""" # Setup mock_get_subclasses.return_value = { - 'Variant:ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=False), - 'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False), - 'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True), - 'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True), - 'DataIdentity': Mock(_NATIVELY_SUPPORTED=True), + 'Variant:Synthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'), + 'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'), + 'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), + 'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), + 'MultiTableSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='multi_table'), + 'DataIdentity': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), } expected_synthesizers = [ 'ColumnSynthesizer', @@ -26,9 +64,11 @@ def test__get_supported_synthesizers_mock(self, mock_get_subclasses): ] # Run - synthesizers = BaselineSynthesizer._get_supported_synthesizers() + synthesizers = BaselineSynthesizer._get_supported_synthesizers('single_table') # Assert + mock_validate_modality.assert_called_once_with('single_table') + mock_get_subclasses.assert_called_once_with(include_parents=True) assert synthesizers == expected_synthesizers def test_get_trained_synthesizer(self): @@ -58,3 +98,45 @@ def test_get_trained_synthesizer(self): assert args[1].to_dict() == metadata.to_dict() assert isinstance(args[1], Metadata) assert instance._get_trained_synthesizer.return_value == mock_synthesizer + + +class TestMultiTableBaselineSynthesizer: + @pytest.mark.parametrize( + 'scale, expected_scale', + [ + (None, 1.0), + (2.0, 2.0), + ], + ) + def test_sample_from_synthesizer_valid(self, scale, expected_scale): + """Test that valid calls return correct values and call underlying method.""" + synthesizer = MultiTableBaselineSynthesizer() + mock_synthesizer = Mock() + synthesizer._sample_from_synthesizer = Mock(return_value='sampled_data') + + # Run + if scale is None: + result = synthesizer.sample_from_synthesizer(mock_synthesizer) + else: + result = synthesizer.sample_from_synthesizer(mock_synthesizer, scale) + + # Assert call + synthesizer._sample_from_synthesizer.assert_called_with(mock_synthesizer, expected_scale) + + assert result == 'sampled_data' + assert synthesizer._MODALITY_FLAG == 'multi_table' + + def test_sample_from_synthesizer_raises_on_unexpected_kwarg(self): + """Test that passing n_samples raises a TypeError.""" + synthesizer = MultiTableBaselineSynthesizer() + mock_synthesizer = Mock() + + expected_error = re.escape( + "sample_from_synthesizer() got an unexpected keyword argument 'n_samples'" + ) + + with pytest.raises(TypeError, match=expected_error): + synthesizer.sample_from_synthesizer( + mock_synthesizer, + n_samples=10, + ) diff --git a/tests/unit/synthesizers/test_generate.py 
b/tests/unit/synthesizers/test_generate.py index 45e25a6d..72ad0c81 100644 --- a/tests/unit/synthesizers/test_generate.py +++ b/tests/unit/synthesizers/test_generate.py @@ -55,7 +55,7 @@ def test_create_sdv_variant_synthesizer(): # Assert assert out.__name__ == 'Variant:test_synth' - assert out.modality == 'single_table' + assert out._MODALITY_FLAG == 'single_table' assert out._MODEL_KWARGS == synthesizer_parameters assert out.SDV_NAME == synthesizer_class assert out._NATIVELY_SUPPORTED is False @@ -85,7 +85,7 @@ def test_create_sdv_variant_synthesizer_multi_table(): # Assert assert out.__name__ == 'Variant:test_synth' - assert out.modality == 'multi_table' + assert out._MODALITY_FLAG == 'multi_table' assert out._MODEL_KWARGS == synthesizer_parameters assert out.SDV_NAME == synthesizer_class assert out._NATIVELY_SUPPORTED is False @@ -103,13 +103,15 @@ def test__create_synthesizer_class(): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='num_samples', + modality='single_table', ) # Assert assert out.__name__ == 'Custom:test_synth' assert hasattr(out, 'get_trained_synthesizer') assert hasattr(out, 'sample_from_synthesizer') + assert out._NATIVELY_SUPPORTED is False + assert out._MODALITY_FLAG == 'single_table' @patch('sdgym.synthesizers.generate._create_synthesizer_class') @@ -128,7 +130,7 @@ def test_create_single_table_synthesizer_mock(mock_create_class): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='num_samples', + modality='single_table', ) assert out == 'synthesizer_class' @@ -149,6 +151,6 @@ def test_create_multi_table_synthesizer_mock(mock_create_class): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='scale', + modality='multi_table', ) assert out == 'synthesizer_class' diff --git a/tests/unit/synthesizers/test_realtabformer.py b/tests/unit/synthesizers/test_realtabformer.py index 89146c1a..a66a8cb8 100644 --- a/tests/unit/synthesizers/test_realtabformer.py +++ b/tests/unit/synthesizers/test_realtabformer.py @@ -43,13 +43,17 @@ def test__get_trained_synthesizer(self, mock_real_tab_former): # Assert mock_real_tab_former.assert_called_once_with(model_type='tabular') mock_model.fit.assert_called_once_with(data) - assert result == mock_model, 'Expected the trained model to be returned.' 
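# Editorial aside, not part of the patch: the assertions on either side of
# this point capture the new RealTabFormerSynthesizer training contract --
# the SDGym wrapper itself is returned, with the fitted REaLTabFormer model
# attached as `_internal_synthesizer`, rather than the bare model. A minimal
# sketch inferred from these tests (the class body is an assumption, not
# sdgym's actual implementation):
from realtabformer import REaLTabFormer


class RealTabFormerSketch:
    def _get_trained_synthesizer(self, data, metadata):
        model = REaLTabFormer(model_type='tabular')
        model.fit(data)
        self._internal_synthesizer = model
        return self  # the wrapper is returned, not the raw model

    def _sample_from_synthesizer(self, synthesizer, n_samples):
        # sampling delegates to the wrapped model, as the updated
        # assertions in test__sample_from_synthesizer verify
        return synthesizer._internal_synthesizer.sample(n_samples)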
+ assert result._internal_synthesizer == mock_model + assert isinstance(result, RealTabFormerSynthesizer) def test__sample_from_synthesizer(self): """Test _sample_from_synthesizer generates data with the specified sample size.""" # Setup trained_model = MagicMock() - trained_model.sample.return_value = MagicMock(shape=(10, 5)) # Mock sample data shape + trained_model._internal_synthesizer = MagicMock() + trained_model._internal_synthesizer.sample.return_value = MagicMock( + shape=(10, 5) + ) # Mock sample data shape n_sample = 10 synthesizer = RealTabFormerSynthesizer() @@ -57,7 +61,7 @@ def test__sample_from_synthesizer(self): synthetic_data = synthesizer._sample_from_synthesizer(trained_model, n_sample) # Assert - trained_model.sample.assert_called_once_with(n_sample) + trained_model._internal_synthesizer.sample.assert_called_once_with(n_sample) assert synthetic_data.shape[0] == n_sample, ( f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}' ) diff --git a/tests/unit/synthesizers/test_sdv.py b/tests/unit/synthesizers/test_sdv.py index c0153948..cd6cb65a 100644 --- a/tests/unit/synthesizers/test_sdv.py +++ b/tests/unit/synthesizers/test_sdv.py @@ -9,37 +9,16 @@ from sdgym.synthesizers.base import BaselineSynthesizer from sdgym.synthesizers.sdv import ( _create_sdv_class, + _fit, _get_all_sdv_synthesizers, _get_modality, _get_sdv_synthesizers, - _get_trained_synthesizer, _retrieve_sdv_class, _sample_from_synthesizer, - _validate_modality, create_sdv_synthesizer_class, ) -def test__validate_modality(): - """Test the `_validate_modality` method.""" - # Setup - valid_modalities = ['single_table', 'multi_table'] - - # Run and Assert - for modality in valid_modalities: - _validate_modality(modality) - - -def test__validate_modality_invalid(): - """Test the `_validate_modality` method with invalid modality.""" - # Setup - expected_error = re.escape("`modality` must be one of 'single_table' or 'multi_table'.") - - # Run and Assert - with pytest.raises(ValueError, match=expected_error): - _validate_modality('invalid_modality') - - def test__get_sdv_synthesizers(): """Test the `_get_sdv_synthesizers` method.""" # Setup @@ -79,8 +58,8 @@ def test__get_all_sdv_synthesizers(): @patch('sdgym.synthesizers.sdv.LOGGER') -def test__get_trained_synthesizer(mock_logger): - """Test the `_get_trained_synthesizer` method.""" +def test__fit(mock_logger): + """Test the `_fit` method.""" # Setup data = pd.DataFrame({ 'column1': [1, 2, 3, 4, 5], @@ -99,16 +78,17 @@ def test__get_trained_synthesizer(mock_logger): synthesizer = Mock() synthesizer.__class__.__name__ = 'GaussianCopulaClass' synthesizer._MODEL_KWARGS = {'enforce_min_max_values': False} - synthesizer.modality = 'single_table' + synthesizer._MODALITY_FLAG = 'single_table' synthesizer.SDV_NAME = 'GaussianCopulaSynthesizer' # Run - valid_model = _get_trained_synthesizer(synthesizer, data, metadata) + _fit(synthesizer, data, metadata) # Assert mock_logger.info.assert_called_with('Fitting %s', 'GaussianCopulaClass') - assert isinstance(valid_model, GaussianCopulaSynthesizer) - assert valid_model.enforce_min_max_values is False + assert isinstance(synthesizer._internal_synthesizer, GaussianCopulaSynthesizer) + assert synthesizer._internal_synthesizer.enforce_min_max_values is False + assert synthesizer._internal_synthesizer._fitted is True @patch('sdgym.synthesizers.sdv.LOGGER') @@ -121,9 +101,10 @@ def test__sample_from_synthesizer(mock_logger): }) base_synthesizer = Mock() base_synthesizer.__class__.__name__ = 'GaussianCopulaSynthesizer' - 
base_synthesizer.modality = 'single_table' + base_synthesizer._MODALITY_FLAG = 'single_table' synthesizer = Mock() - synthesizer.sample.return_value = data + synthesizer._internal_synthesizer = Mock() + synthesizer._internal_synthesizer.sample.return_value = data n_samples = 3 # Run @@ -132,7 +113,7 @@ def test__sample_from_synthesizer(mock_logger): # Assert mock_logger.info.assert_called_with('Sampling %s', 'GaussianCopulaSynthesizer') pd.testing.assert_frame_equal(sampled_data, data) - synthesizer.sample.assert_called_once_with(num_rows=n_samples) + synthesizer._internal_synthesizer.sample.assert_called_once_with(num_rows=n_samples) @patch('sdgym.synthesizers.sdv.sys.modules') @@ -187,15 +168,15 @@ def test__create_sdv_class_mock(mock_get_modality, mock_sys_modules): # Assert assert synt_class.__name__ == sdv_name - assert synt_class.modality == 'single_table' + assert synt_class._MODALITY_FLAG == 'single_table' assert synt_class._MODEL_KWARGS == {} assert synt_class.SDV_NAME == sdv_name assert issubclass(synt_class, BaselineSynthesizer) - assert getattr(synt_class, '_get_trained_synthesizer') is _get_trained_synthesizer + assert getattr(synt_class, '_fit') is _fit assert getattr(synt_class, '_sample_from_synthesizer') is _sample_from_synthesizer assert getattr(fake_module, sdv_name) is synt_class - assert instance._get_trained_synthesizer.__self__ is instance - assert instance._get_trained_synthesizer.__func__ is _get_trained_synthesizer + assert instance._fit.__self__ is instance + assert instance._fit.__func__ is _fit assert instance._sample_from_synthesizer.__self__ is instance assert instance._sample_from_synthesizer.__func__ is _sample_from_synthesizer assert instance.SDV_NAME == sdv_name @@ -212,7 +193,7 @@ def test__create_sdv_class(): # Assert assert synthesizer_class.__name__ == sdv_name - assert synthesizer_class.modality == 'single_table' + assert synthesizer_class._MODALITY_FLAG == 'single_table' assert synthesizer_class._MODEL_KWARGS == {} assert issubclass(synthesizer_class, BaselineSynthesizer) diff --git a/tests/unit/synthesizers/test_uniform.py b/tests/unit/synthesizers/test_uniform.py index 8269938c..d7258cd3 100644 --- a/tests/unit/synthesizers/test_uniform.py +++ b/tests/unit/synthesizers/test_uniform.py @@ -1,12 +1,17 @@ +from unittest.mock import Mock, call, patch + import numpy as np import pandas as pd +from rdt import HyperTransformer +from sdv.metadata import Metadata -from sdgym.synthesizers.uniform import UniformSynthesizer +from sdgym.synthesizers.uniform import MultiTableUniformSynthesizer, UniformSynthesizer class TestUniformSynthesizer: def test_uniform_synthesizer_sdtypes(self): """Ensure that sdtypes uniform are taken from metadata instead of inferred.""" + # Setup uniform_synthesizer = UniformSynthesizer() metadata = { 'primary_key': 'guest_email', @@ -67,8 +72,11 @@ def test_uniform_synthesizer_sdtypes(self): } real_data = pd.DataFrame(data) + + # Run synthesizer = uniform_synthesizer.get_trained_synthesizer(real_data, metadata) - hyper_transformer_config = synthesizer[0].get_config() + + hyper_transformer_config = synthesizer.hyper_transformer.get_config() config_sdtypes = hyper_transformer_config['sdtypes'] unknown_sdtypes = ['email', 'credit_card_number', 'address'] for column in metadata['columns']: @@ -78,3 +86,221 @@ def test_uniform_synthesizer_sdtypes(self): assert metadata_sdtype == config_sdtypes[column] else: assert config_sdtypes[column] == 'pii' + + +class TestMultiTableUniformSynthesizer: + 
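# Editorial aside, not part of the patch: the tests in this class pin down
# the shape of MultiTableUniformSynthesizer -- one UniformSynthesizer per
# table plus a record of each table's original row count. A hedged sketch
# reconstructed from the assertions below (attribute and method names match
# the tests; the body itself is an assumption):
from sdgym.synthesizers.uniform import UniformSynthesizer


class MultiTableUniformSketch:
    def __init__(self):
        self.table_synthesizers = {}
        self.num_rows_per_table = {}

    def _fit(self, data, metadata):
        # fit one single-table UniformSynthesizer per table and remember
        # how many rows the original table had
        for table_name, table_data in data.items():
            table_synthesizer = UniformSynthesizer()
            table_synthesizer._fit(table_data, metadata.get_table_metadata(table_name))
            self.table_synthesizers[table_name] = table_synthesizer
            self.num_rows_per_table[table_name] = len(table_data)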
@patch('sdgym.synthesizers.uniform.MultiTableBaselineSynthesizer.__init__') + def test__init__(self, mock_baseline_init): + """Test the `__init__` method.""" + # Run + synthesizer = MultiTableUniformSynthesizer() + + # Assert + mock_baseline_init.assert_called_once() + assert synthesizer.num_rows_per_table == {} + + @patch('sdgym.synthesizers.uniform.UniformSynthesizer._fit') + def test__fit_mock(self, mock_uniform_fit): + """Test the `fit` method with mocking.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': ['A', 'B', 'C'], + }), + 'table2': pd.DataFrame({ + 'col3': [10.0, 20.0, 30.0], + 'col4': [True, False, True], + }), + } + metadata = Mock() + st_metadatas = [ + { + 'primary_key': 'col1', + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + }, + { + 'primary_key': 'col3', + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + }, + ] + metadata.get_table_metadata.side_effect = st_metadatas + + # Run + synthesizer._fit(data, metadata) + + # Assert + metadata.get_table_metadata.assert_has_calls([ + call('table1'), + call('table2'), + ]) + mock_uniform_fit.assert_has_calls([ + call(data['table1'], st_metadatas[0]), + call(data['table2'], st_metadatas[1]), + ]) + assert synthesizer.num_rows_per_table == { + 'table1': 3, + 'table2': 3, + } + for table_name, table_synthesizer in synthesizer.table_synthesizers.items(): + assert table_name in ('table1', 'table2') + assert isinstance(table_synthesizer, UniformSynthesizer) + + def test__get_trained_synthesizer(self): + """Test the `_get_trained_synthesizer` method.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'D', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [10.0, 20.0, 30.0], + 'col4': [True, False, True], + }), + } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + 'primary_key': 'col1', + }, + 'table2': { + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + 'primary_key': 'col3', + }, + }, + 'relationships': [], + }) + + # Run + trained_synthesizer = synthesizer._get_trained_synthesizer(data, metadata) + + # Assert + assert trained_synthesizer.num_rows_per_table == { + 'table1': 5, + 'table2': 3, + } + assert set(trained_synthesizer.table_synthesizers.keys()) == {'table1', 'table2'} + for table_name, table_synthesizer in trained_synthesizer.table_synthesizers.items(): + hyper_transformer = table_synthesizer.hyper_transformer + transformed = table_synthesizer.transformed_data + assert isinstance(hyper_transformer, HyperTransformer) + assert isinstance(transformed, pd.DataFrame) + assert set(transformed.columns) == set(data[table_name].columns) + + @patch('sdgym.synthesizers.uniform.UniformSynthesizer.sample_from_synthesizer') + def test_sample_from_synthesizer_mock(self, mock_sample_from_synthesizer): + """Test the `sample_from_synthesizer` method with mocking.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + trained_synthesizer = MultiTableUniformSynthesizer() + trained_synthesizer.num_rows_per_table = { + 'table1': 3, + 'table2': 2, + } + synthesizer_table1 = Mock() + synthesizer_table2 = Mock() + trained_synthesizer.table_synthesizers = { + 'table1': synthesizer_table1, + 'table2': synthesizer_table2, + } + 
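# Editorial aside, not part of the patch: the calls asserted below encode
# the scaling rule of the multi-table uniform baseline -- every table is
# sampled independently at `scale` times its original row count. A
# self-contained sketch of that rule (`sample_single_table` is a stand-in
# for the patched UniformSynthesizer.sample_from_synthesizer; the loop is
# an assumed implementation):
def _sample_tables_sketch(trained_synthesizer, sample_single_table, scale=1.0):
    sampled = {}
    for table_name, table_synthesizer in trained_synthesizer.table_synthesizers.items():
        # 3 rows * scale 2 -> n_samples=6 for table1; 2 * 2 -> 4 for table2
        n_samples = int(trained_synthesizer.num_rows_per_table[table_name] * scale)
        sampled[table_name] = sample_single_table(table_synthesizer, n_samples=n_samples)
    return sampled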
mock_sample_from_synthesizer.side_effect = [ + 'sampled_data_table1', + 'sampled_data_table2', + ] + scale = 2 + + # Run + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=scale) + + # Assert + assert sampled_data == { + 'table1': 'sampled_data_table1', + 'table2': 'sampled_data_table2', + } + mock_sample_from_synthesizer.assert_has_calls([ + call(synthesizer_table1, n_samples=6), + call(synthesizer_table2, n_samples=4), + ]) + + def test_sample_from_synthesizer(self): + """Test the `sample_from_synthesizer` method.""" + # Setup + np.random.seed(0) + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'D', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [10, 20, 30], + 'col4': [True, False, True], + }), + } + table_1 = UniformSynthesizer() + table_1._fit( + data['table1'], + Metadata.load_from_dict({ + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + }), + ) + table_2 = UniformSynthesizer() + table_2._fit( + data['table2'], + Metadata.load_from_dict({ + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + }), + ) + trained_synthesizer = MultiTableUniformSynthesizer() + + trained_synthesizer.table_synthesizers = { + 'table1': table_1, + 'table2': table_2, + } + trained_synthesizer.num_rows_per_table = { + 'table1': 5, + 'table2': 3, + } + scale = 2 + expected_data = { + 'table1': pd.DataFrame({ + 'col1': [5, 1, 4, 4, 4, 2, 4, 3, 5, 1], + 'col2': ['A', 'E', 'C', 'B', 'A', 'B', 'B', 'A', 'B', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [29, 26, 29, 15, 25, 25], + 'col4': [True, True, False, True, False, False], + }), + } + + # Run + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=scale) + + # Assert + for table_name, table_data in sampled_data.items(): + pd.testing.assert_frame_equal( + table_data, + expected_data[table_name], + ) diff --git a/tests/unit/synthesizers/test_utils.py b/tests/unit/synthesizers/test_utils.py index 77d19dd1..0881994a 100644 --- a/tests/unit/synthesizers/test_utils.py +++ b/tests/unit/synthesizers/test_utils.py @@ -1,35 +1,4 @@ -from unittest.mock import patch - -from sdgym.synthesizers.utils import _get_sdgym_synthesizers, _get_supported_synthesizers - - -@patch('sdgym.synthesizers.utils.BaselineSynthesizer._get_supported_synthesizers') -def test__get_sdgym_synthesizers(mock_get_supported_synthesizers): - """Test the `_get_sdgym_synthesizers` method.""" - # Setup - mock_get_supported_synthesizers.return_value = [ - 'ColumnSynthesizer', - 'UniformSynthesizer', - 'DataIdentity', - 'RealTabFormerSynthesizer', - 'CTGANSynthesizer', - 'CopulaGANSynthesizer', - 'GaussianCopulaSynthesizer', - 'HMASynthesizer', - 'TVAESynthesizer', - ] - expected_synthesizers = [ - 'ColumnSynthesizer', - 'DataIdentity', - 'RealTabFormerSynthesizer', - 'UniformSynthesizer', - ] - - # Run - synthesizers = _get_sdgym_synthesizers() - - # Assert - assert synthesizers == expected_synthesizers +from sdgym.synthesizers.utils import _get_supported_synthesizers def test__get_supported_synthesizers(): @@ -42,6 +11,7 @@ def test__get_supported_synthesizers(): 'DataIdentity', 'GaussianCopulaSynthesizer', 'HMASynthesizer', + 'MultiTableUniformSynthesizer', 'RealTabFormerSynthesizer', 'TVAESynthesizer', 'UniformSynthesizer', diff --git a/tests/unit/test__dataset_utils.py b/tests/unit/test__dataset_utils.py new file mode 100644 index 00000000..f979bf1c --- /dev/null +++ 
b/tests/unit/test__dataset_utils.py @@ -0,0 +1,200 @@ +import json +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +from sdgym._dataset_utils import ( + _get_dataset_subset, + _get_multi_table_dataset_subset, + _parse_numeric_value, + _read_csv_from_zip, + _read_metadata_json, + _read_zipped_data, +) + + +@pytest.mark.parametrize( + 'value,expected', + [ + ('3.14', 3.14), + ('not-a-number', np.nan), + (None, np.nan), + ], +) +def test__parse_numeric_value(value, expected): + """Test numeric parsing with fallback to NaN.""" + # Setup / Run + result = _parse_numeric_value(value, 'dataset', 'field') + + # Assert + if np.isnan(expected): + assert np.isnan(result) + else: + assert result == expected + + +@patch('sdgym._dataset_utils.poc.get_random_subset') +@patch('sdgym._dataset_utils.Metadata') +def test__get_multi_table_dataset_subset(mock_metadata, mock_subset): + """Test multi-table subset selection calls SDV and trims columns.""" + # Setup + df_main = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + df_other = pd.DataFrame({'x': [5, 6], 'y': [7, 8]}) + + data = {'main': df_main, 'other': df_other} + + metadata_dict = { + 'tables': { + 'main': {'columns': {'a': {}, 'b': {}}}, + 'other': {'columns': {'x': {}, 'y': {}}}, + } + } + + mock_meta_obj = MagicMock() + mock_meta_obj.tables = { + 'main': MagicMock(columns={'a': {}, 'b': {}}), + 'other': MagicMock(columns={'x': {}, 'y': {}}), + } + mock_meta_obj._get_all_keys.return_value = [] + mock_metadata.load_from_dict.return_value = mock_meta_obj + + mock_subset.return_value = {'main': df_main[:1], 'other': df_other[:1]} + + # Run + result_data, result_meta = _get_multi_table_dataset_subset(data, metadata_dict) + + # Assert + assert 'main' in result_data + assert 'other' in result_data + mock_subset.assert_called_once() + + +def test__get_dataset_subset_single_table(): + """Test tabular dataset subset reduces rows and columns.""" + # Setup + df = pd.DataFrame({f'c{i}': range(2000) for i in range(15)}) + metadata = {'tables': {'table': {'columns': {f'c{i}': {} for i in range(15)}}}} + + # Run + result_df, result_meta = _get_dataset_subset(df, metadata, modality='regular') + + # Assert + assert len(result_df) <= 1000 + assert len(result_df.columns) == 10 + assert 'tables' in result_meta + + +def test__get_dataset_subset_sequential(): + """Test sequential dataset preserves mandatory columns.""" + # Setup + df = pd.DataFrame({ + 'seq_id': range(20), + 'seq_key': range(20), + **{f'c{i}': range(20) for i in range(20)}, + }) + + metadata = { + 'tables': { + 'table': { + 'columns': {col: {'sdtype': 'numerical'} for col in df.columns.to_list()}, + 'sequence_index': 'seq_id', + 'sequence_key': 'seq_key', + } + } + } + + # Run + subset_df, _ = _get_dataset_subset(df, metadata, modality='sequential') + + # Assert + assert 'seq_id' in subset_df.columns + assert 'seq_key' in subset_df.columns + assert len(subset_df.columns) <= 10 + + +@patch('sdgym._dataset_utils._get_multi_table_dataset_subset') +def test__get_dataset_subset_multi_table(mock_multi): + """Test multi-table dispatch calls the correct function.""" + # Setup + data = {'table': pd.DataFrame({'a': [1, 2]})} + metadata = {'tables': {}} + mock_multi.return_value = ('DATA', 'META') + + # Run + out_data, out_meta = _get_dataset_subset(data, metadata, modality='multi_table') + + # Assert + assert out_data == 'DATA' + assert out_meta == 'META' + mock_multi.assert_called_once() + + +@patch('sdgym._dataset_utils._read_csv_from_zip') +def 
test__read_zipped_data_multitable(mock_read): + """Test zipped CSV reading returns a dict for multi-table.""" + # Setup + mock_read.return_value = pd.DataFrame({'a': [1]}) + + mock_zip = MagicMock() + mock_zip.__enter__.return_value = mock_zip + mock_zip.namelist.return_value = ['table1.csv', 'table2.csv'] + + # Run + with patch('sdgym._dataset_utils.ZipFile', return_value=mock_zip): + data_multi = _read_zipped_data('fake.zip', modality='multi_table') + + # Assert + assert isinstance(data_multi, dict) + assert mock_read.call_count == 2 + + +@patch('sdgym._dataset_utils._read_csv_from_zip') +def test__read_zipped_data_single(mock_read): + """Test zipped CSV reading returns a DataFrame for single-table.""" + # Setup + mock_read.return_value = pd.DataFrame({'a': [1]}) + + mock_zip = MagicMock() + mock_zip.__enter__.return_value = mock_zip + mock_zip.namelist.return_value = ['table1.csv'] + + # Run + with patch('sdgym._dataset_utils.ZipFile', return_value=mock_zip): + data_single = _read_zipped_data('fake.zip', modality='single') + + # Assert + assert isinstance(data_single, pd.DataFrame) + assert mock_read.call_count == 1 + + +@patch('sdgym._dataset_utils.pd') +def test__read_csv_from_zip(mock_pd): + """Test CSV is read from zip and returned as DataFrame.""" + # Setup + csv_bytes = b'a,b\n1,2\n3,4\n' + returned_bytes = csv_bytes.decode().splitlines() + mock_zip = MagicMock() + mock_zip.open.return_value.__enter__.return_value = returned_bytes + + # Run + result = _read_csv_from_zip(mock_zip, 'fake.csv') + + # Assert + mock_pd.read_csv.assert_called_once_with(returned_bytes, low_memory=False) + assert result == mock_pd.read_csv.return_value + + +def test__read_metadata_json(tmp_path): + """Test reading metadata JSON file.""" + # Setup + meta = {'tables': {'a': {}}} + path = tmp_path / 'meta.json' + path.write_text(json.dumps(meta)) + + # Run + result = _read_metadata_json(path) + + # Assert + assert result == meta diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 515e29b1..35c4bb9c 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -11,7 +11,6 @@ import pytest import yaml -from sdgym import benchmark_single_table from sdgym.benchmark import ( _add_adjusted_scores, _check_write_permissions, @@ -29,6 +28,8 @@ _validate_aws_inputs, _validate_output_destination, _write_metainfo_file, + benchmark_multi_table, + benchmark_single_table, benchmark_single_table_aws, ) from sdgym.result_writer import LocalResultsWriter @@ -522,18 +523,30 @@ def test__validate_output_destination_with_aws_access_key_ids(mock_validate): ) -def test__setup_output_destination(tmp_path): - """Test the `_setup_output_destination` function.""" +def test__setup_output_destination_none(): + """If output_destination is None, the function should return an empty dict.""" + # Setup + synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'] + datasets = ['adult', 'census'] + + # Run + result = _setup_output_destination(None, synthesizers, datasets, 'single_table') + + # Assert + assert result == {} + + +def test__setup_output_destination_single_table(tmp_path): + """Test the `_setup_output_destination` function with `single_table` modality.""" # Setup output_destination = tmp_path / 'output_destination' synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'] datasets = ['adult', 'census'] today = datetime.today().strftime('%m_%d_%Y') - base_path = output_destination / f'SDGym_results_{today}' + base_path = output_destination / 'single_table' / 
f'SDGym_results_{today}' # Run - result_1 = _setup_output_destination(None, synthesizers, datasets) - result_2 = _setup_output_destination(output_destination, synthesizers, datasets) + result = _setup_output_destination(output_destination, synthesizers, datasets, 'single_table') # Assert expected = { @@ -553,9 +566,41 @@ def test__setup_output_destination(tmp_path): } for dataset in datasets } + assert json.loads(json.dumps(result)) == expected + + +def test__setup_output_destination_multi_table(tmp_path): + """Test the `_setup_output_destination` function with `multi_table` modality.""" + # Setup + output_destination = tmp_path / 'output_destination' + synthesizers = ['HMASynthesizer'] + datasets = ['NBA', 'financial'] + today = datetime.today().strftime('%m_%d_%Y') + base_path = output_destination / 'multi_table' / f'SDGym_results_{today}' + + # Run + result = _setup_output_destination(output_destination, synthesizers, datasets, 'multi_table') + + # Assert + expected = { + dataset: { + synth: { + 'synthesizer': str(base_path / f'{dataset}_{today}' / synth / f'{synth}.pkl'), + 'synthetic_data': str( + base_path / f'{dataset}_{today}' / synth / f'{synth}_synthetic_data.zip' + ), + 'benchmark_result': str( + base_path / f'{dataset}_{today}' / synth / f'{synth}_benchmark_result.csv' + ), + 'metainfo': str(base_path / 'metainfo.yaml'), + 'results': str(base_path / 'results.csv'), + } + for synth in synthesizers + } + for dataset in datasets + } - assert result_1 == {} - assert json.loads(json.dumps(result_2)) == expected + assert json.loads(json.dumps(result)) == expected @patch('sdgym.benchmark.datetime') @@ -575,19 +620,21 @@ def test__write_metainfo_file(mock_datetime, tmp_path): synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'RealTabFormerSynthesizer'] # Run - _write_metainfo_file(synthesizers, jobs, result_writer) + _write_metainfo_file(synthesizers, jobs, 'single_table', result_writer) # Assert - assert Path(file_name['metainfo']).exists() with open(file_name['metainfo'], 'r') as file: metainfo_data = yaml.safe_load(file) - assert metainfo_data['run_id'] == 'run_06_26_2025_0' - assert metainfo_data['starting_date'] == '06_26_2025' - assert metainfo_data['jobs'] == expected_jobs - assert metainfo_data['sdgym_version'] == version('sdgym') - assert metainfo_data['sdv_version'] == version('sdv') - assert metainfo_data['realtabformer_version'] == version('realtabformer') - assert metainfo_data['completed_date'] is None + + assert Path(file_name['metainfo']).exists() + assert metainfo_data['run_id'] == 'run_06_26_2025_0' + assert metainfo_data['starting_date'] == '06_26_2025' + assert metainfo_data['jobs'] == expected_jobs + assert metainfo_data['sdgym_version'] == version('sdgym') + assert metainfo_data['sdv_version'] == version('sdv') + assert metainfo_data['realtabformer_version'] == version('realtabformer') + assert metainfo_data['completed_date'] is None + assert metainfo_data['modality'] == 'single_table' @patch('sdgym.benchmark.datetime') @@ -794,6 +841,7 @@ def test_benchmark_single_table_aws( detailed_results_folder=None, custom_synthesizers=None, s3_client='s3_client_mock', + modality='single_table', ) mock_run_on_aws.assert_called_once_with( output_destination=output_destination, @@ -852,6 +900,7 @@ def test_benchmark_single_table_aws_synthesizers_none( detailed_results_folder=None, custom_synthesizers=None, s3_client='s3_client_mock', + modality='single_table', ) mock_run_on_aws.assert_called_once_with( output_destination=output_destination, @@ -1001,13 +1050,18 
@@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) @patch('sdgym.benchmark.get_dataset_paths') -def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path): +def test__generate_job_args_list_local_root_additional_folder( + get_dataset_paths_mock, + tmp_path, + modality, +): """Local additional_datasets_folder should point to root/single_table.""" # Setup local_root = tmp_path / 'my_root' local_root.mkdir() - dataset_path = tmp_path / 'my_root' / 'single_table' / 'datasetA' + dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] # Run @@ -1025,12 +1079,13 @@ def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_ synthesizers=[], custom_synthesizers=None, s3_client=None, + modality=modality, ) # Assert get_dataset_paths_mock.assert_called_once_with( - modality='single_table', - bucket=str(local_root / 'single_table'), + modality=modality, + bucket=str(local_root / modality), aws_access_key_id=None, aws_secret_access_key=None, ) @@ -1059,6 +1114,7 @@ def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_moc synthesizers=[], custom_synthesizers=None, s3_client=None, + modality='single_table', ) # Assert @@ -1097,3 +1153,148 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): warnings_text = ' '.join(str(w.message) for w in recwarn) assert 'is incompatible with transformer' not in warnings_text pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) + + +@patch('sdgym.benchmark._update_metainfo_file') +@patch('sdgym.benchmark._write_metainfo_file') +@patch('sdgym.benchmark._run_jobs') +@patch('sdgym.benchmark._generate_job_args_list') +@patch('sdgym.benchmark._validate_inputs') +@patch('sdgym.benchmark.LocalResultsWriter') +@patch('sdgym.benchmark._validate_output_destination') +def test_benchmark_multi_table_with_jobs( + mock__validate_output_destination, + mock_LocalResultsWriter, + mock__validate_inputs, + mock__generate_job_args_list, + mock__run_jobs, + mock__write_metainfo_file, + mock__update_metainfo_file, +): + """Test that `benchmark_multi_table` runs jobs and updates metainfo when there are job args.""" + # Setup + fake_scores = pd.DataFrame({'a': [1]}) + mock__run_jobs.return_value = fake_scores + job_args = ('arg1', 'arg2', {'metainfo': 'meta.yaml'}) + mock__generate_job_args_list.return_value = [job_args] + + # Run + scores = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + custom_synthesizers=['CustomSynth'], + sdv_datasets=['dataset1'], + additional_datasets_folder='extra', + limit_dataset_size=True, + compute_quality_score=True, + compute_diagnostic_score=True, + timeout=10, + output_destination='output_dir', + show_progress=True, + ) + + # Assert + mock__validate_output_destination.assert_called_once_with('output_dir') + mock_LocalResultsWriter.assert_called_once_with() + mock__validate_inputs.assert_called_once_with( + output_filepath=None, + detailed_results_folder=None, + synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'], + custom_synthesizers=['CustomSynth'], + ) + mock__generate_job_args_list.assert_called_once_with( + limit_dataset_size=True, + sdv_datasets=['dataset1'], + additional_datasets_folder='extra', + sdmetrics=None, + detailed_results_folder=None, + timeout=10, + output_destination='output_dir', + compute_quality_score=True, + 
compute_diagnostic_score=True, + compute_privacy_score=None, + synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'], + custom_synthesizers=['CustomSynth'], + s3_client=None, + modality='multi_table', + ) + mock__write_metainfo_file.assert_called_once() + mock__run_jobs.assert_called_once_with( + multi_processing_config=None, + job_args_list=[job_args], + show_progress=True, + result_writer=mock_LocalResultsWriter.return_value, + ) + mock__update_metainfo_file.assert_called_once_with( + 'meta.yaml', + mock_LocalResultsWriter.return_value, + ) + pd.testing.assert_frame_equal(scores, fake_scores) + + +@patch('sdgym.benchmark._get_empty_dataframe') +@patch('sdgym.benchmark._write_metainfo_file') +@patch('sdgym.benchmark._generate_job_args_list') +@patch('sdgym.benchmark._validate_inputs') +@patch('sdgym.benchmark.LocalResultsWriter') +@patch('sdgym.benchmark._validate_output_destination') +def test_benchmark_multi_table_no_jobs( + mock__validate_output_destination, + mock_LocalResultsWriter, + mock__validate_inputs, + mock__generate_job_args_list, + mock__write_metainfo_file, + mock__get_empty_dataframe, +): + """Test that benchmark_multi_table returns empty dataframe when there are no job args.""" + # Setup + empty_scores = pd.DataFrame() + mock__generate_job_args_list.return_value = [] + mock__get_empty_dataframe.return_value = empty_scores + + # Run + scores = benchmark_multi_table( + synthesizers=[], + custom_synthesizers=None, + sdv_datasets=None, + additional_datasets_folder=None, + limit_dataset_size=False, + compute_quality_score=False, + compute_diagnostic_score=True, + timeout=None, + output_destination=None, + show_progress=False, + ) + + # Assert + mock__validate_output_destination.assert_called_once_with(None) + mock_LocalResultsWriter.assert_called_once_with() + mock__validate_inputs.assert_called_once_with( + output_filepath=None, + detailed_results_folder=None, + synthesizers=['MultiTableUniformSynthesizer'], + custom_synthesizers=None, + ) + mock__generate_job_args_list.assert_called_once_with( + limit_dataset_size=False, + sdv_datasets=None, + additional_datasets_folder=None, + sdmetrics=None, + detailed_results_folder=None, + timeout=None, + output_destination=None, + compute_quality_score=False, + compute_diagnostic_score=True, + compute_privacy_score=None, + synthesizers=['MultiTableUniformSynthesizer'], + custom_synthesizers=None, + s3_client=None, + modality='multi_table', + ) + mock__get_empty_dataframe.assert_called_once_with( + compute_diagnostic_score=True, + compute_quality_score=False, + compute_privacy_score=None, + sdmetrics=None, + ) + mock__write_metainfo_file.assert_called_once() + pd.testing.assert_frame_equal(scores, empty_scores) diff --git a/tests/unit/test_result_writer.py b/tests/unit/test_result_writer.py index 5f338d44..1d4b7be8 100644 --- a/tests/unit/test_result_writer.py +++ b/tests/unit/test_result_writer.py @@ -1,6 +1,6 @@ -import pickle from unittest.mock import Mock, patch +import cloudpickle import pandas as pd import yaml @@ -57,7 +57,7 @@ def test_write_pickle(self, tmp_path): # Assert with open(file_path, 'rb') as f: - loaded_obj = pickle.load(f) + loaded_obj = cloudpickle.load(f) assert loaded_obj == obj @@ -172,7 +172,7 @@ def test_write_pickle(self, mockparse_s3_path): # Assert mockparse_s3_path.assert_called_once_with('test_object.pkl') mock_s3_client.put_object.assert_called_once_with( - Body=pickle.dumps(obj), + Body=cloudpickle.dumps(obj), Bucket='bucket_name', Key='key_prefix/test_object.pkl', )
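A closing note on the last hunk: the result-writer tests now round-trip objects with `cloudpickle` instead of `pickle`. The patch does not state the motivation, but a plausible one is that SDGym builds synthesizer classes dynamically (the `Variant:`/`Custom:` names that appear throughout these tests), and the standard `pickle` module can only serialize classes that are importable by qualified name. A small self-contained illustration of the difference:

```python
import pickle

import cloudpickle

# A dynamically created class, similar in spirit to SDGym's generated
# 'Variant:...' synthesizer classes (this name is made up).
DynamicSynthesizer = type('Variant:MySynth', (), {'name': 'demo'})

# Plain pickle serializes classes by reference and fails here, because
# 'Variant:MySynth' cannot be looked up in any importable module.
try:
    pickle.dumps(DynamicSynthesizer)
except pickle.PicklingError as error:
    print(f'pickle failed: {error}')

# cloudpickle serializes the class definition itself, so it round-trips.
restored = cloudpickle.loads(cloudpickle.dumps(DynamicSynthesizer))
print(restored.name)  # 'demo'
```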