diff --git a/pyproject.toml b/pyproject.toml index 38ba48b0..20817bae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ exclude = [ ".git", "__pycache__", ".ipynb_checkpoints", + ".ipynb", "tasks.py", ] diff --git a/sdmetrics/column_pairs/statistical/contingency_similarity.py b/sdmetrics/column_pairs/statistical/contingency_similarity.py index c093216e..5d2c801d 100644 --- a/sdmetrics/column_pairs/statistical/contingency_similarity.py +++ b/sdmetrics/column_pairs/statistical/contingency_similarity.py @@ -1,7 +1,5 @@ """Contingency Similarity Metric.""" -import pandas as pd - from sdmetrics.column_pairs.base import ColumnPairsMetric from sdmetrics.goal import Goal @@ -42,43 +40,16 @@ def compute(cls, real_data, synthetic_data): columns = real_data.columns[:2] real = real_data[columns] synthetic = synthetic_data[columns] - contingency_real = pd.crosstab( - index=real[columns[0]].astype(str), - columns=real[columns[1]].astype(str), - normalize=True, - ) - contingency_synthetic = pd.crosstab( - index=synthetic[columns[0]].astype(str), - columns=synthetic[columns[1]].astype(str), - normalize=True, - ) - - real_append_cols = {} - for col in set(contingency_synthetic.columns) - set(contingency_real.columns): - real_append_cols[col] = [0 for _ in range(len(contingency_real))] - contingency_real = pd.concat( - (contingency_real, pd.DataFrame(real_append_cols, index=contingency_real.index)), - axis=1, + contingency_real = real.groupby(list(columns), dropna=False).size() / len(real) + contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len( + synthetic ) - - synthetic_append_cols = {} - for col in set(contingency_real.columns) - set(contingency_synthetic.columns): - synthetic_append_cols[col] = [0 for _ in range(len(contingency_synthetic))] - contingency_synthetic = pd.concat( - ( - contingency_synthetic, - pd.DataFrame(synthetic_append_cols, index=contingency_synthetic.index), - ), - axis=1, - ) - - for row in set(contingency_synthetic.index) - set(contingency_real.index): - contingency_real.loc[row] = [0 for _ in range(len(contingency_real.columns))] - for row in set(contingency_real.index) - set(contingency_synthetic.index): - contingency_synthetic.loc[row] = [0 for _ in range(len(contingency_synthetic.columns))] - - variation = abs(contingency_real - contingency_synthetic) / 2 - return 1 - variation.sum().sum() + combined_index = contingency_real.index.union(contingency_synthetic.index) + contingency_synthetic = contingency_synthetic.reindex(combined_index, fill_value=0) + contingency_real = contingency_real.reindex(combined_index, fill_value=0) + diff = abs(contingency_real - contingency_synthetic).fillna(0) + variation = diff / 2 + return 1 - variation.sum() @classmethod def normalize(cls, raw_score): diff --git a/tests/integration/reports/multi_table/_properties/test_inter_table_trends.py b/tests/integration/reports/multi_table/_properties/test_inter_table_trends.py index 5785e7a5..cd1fba2e 100644 --- a/tests/integration/reports/multi_table/_properties/test_inter_table_trends.py +++ b/tests/integration/reports/multi_table/_properties/test_inter_table_trends.py @@ -17,7 +17,7 @@ def test_end_to_end(self): result = inter_table_trends.get_score(real_data, synthetic_data, metadata) # Assert - assert result == 0.4416666666666667 + assert result == 0.4416666666666666 def test_with_progress_bar(self): """Test that the progress bar is correctly updated.""" @@ -38,5 +38,5 @@ def test_with_progress_bar(self): result = inter_table_trends.get_score(real_data, synthetic_data, metadata, progress_bar) # Assert - assert result == 0.4416666666666667 + assert result == 0.4416666666666666 assert mock_update.call_count == num_iter diff --git a/tests/integration/reports/multi_table/test_quality_report.py b/tests/integration/reports/multi_table/test_quality_report.py index 07593d9f..9c3b79b1 100644 --- a/tests/integration/reports/multi_table/test_quality_report.py +++ b/tests/integration/reports/multi_table/test_quality_report.py @@ -232,9 +232,9 @@ def test_quality_report_end_to_end(): # Assert expected_properties = pd.DataFrame({ 'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality', 'Intertable Trends'], - 'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666667], + 'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666], }) - assert score == 0.6615076057048345 + assert score == 0.6615076057048344 pd.testing.assert_frame_equal(properties, expected_properties) expected_info_keys = { 'report_type', @@ -272,9 +272,9 @@ def test_quality_report_with_object_datetimes(): # Assert expected_properties = pd.DataFrame({ 'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality', 'Intertable Trends'], - 'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666667], + 'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666], }) - assert score == 0.6615076057048345 + assert score == 0.6615076057048344 pd.testing.assert_frame_equal(properties, expected_properties) @@ -342,7 +342,7 @@ def test_quality_report_with_errors(): None, ], }) - assert score == 0.7249603174603175 + assert score == 0.7249603174603174 pd.testing.assert_frame_equal(properties, expected_properties) pd.testing.assert_frame_equal(details_column_shapes, expected_details)