From 8bee23f1dca86b5091ed45ce1f273b29d0d866d9 Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 2 Dec 2024 11:28:29 -0800 Subject: [PATCH 1/2] Patch --- sdmetrics/column_pairs/statistical/inter_row_msas.py | 9 ++++++--- sdmetrics/column_pairs/statistical/statistic_msas.py | 3 ++- .../statistical/sequence_length_similarity.py | 3 ++- .../unit/column_pairs/statistical/test_inter_row_msas.py | 7 +++++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sdmetrics/column_pairs/statistical/inter_row_msas.py b/sdmetrics/column_pairs/statistical/inter_row_msas.py index 4755621d..3ebf3e75 100644 --- a/sdmetrics/column_pairs/statistical/inter_row_msas.py +++ b/sdmetrics/column_pairs/statistical/inter_row_msas.py @@ -5,11 +5,12 @@ import numpy as np import pandas as pd +from sdmetrics.column_pairs.base import ColumnPairsMetric from sdmetrics.goal import Goal from sdmetrics.single_column.statistical.kscomplement import KSComplement -class InterRowMSAS: +class InterRowMSAS(ColumnPairsMetric): """Inter-Row Multi-Sequence Aggregate Similarity (MSAS) metric. Attributes: @@ -76,7 +77,7 @@ def _calculate_differences(keys, values, n_rows_diff, data_name): num_invalid_groups = len(group_sizes[group_sizes <= n_rows_diff]) if num_invalid_groups > 0: warnings.warn( - f"n_rows_diff '{n_rows_diff}' is greater than the " + f"n_rows_diff '{n_rows_diff}' is greater or equal to the " f'size of {num_invalid_groups} sequence keys in {data_name}.' ) @@ -84,7 +85,9 @@ def diff_func(group): if len(group) <= n_rows_diff: return np.nan group = group.to_numpy() - return np.mean(group[n_rows_diff:] - group[:-n_rows_diff]) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message='Mean of empty slice') + return np.nanmean(group[n_rows_diff:] - group[:-n_rows_diff]) with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='invalid value encountered in.*') diff --git a/sdmetrics/column_pairs/statistical/statistic_msas.py b/sdmetrics/column_pairs/statistical/statistic_msas.py index 8440618d..529f4c60 100644 --- a/sdmetrics/column_pairs/statistical/statistic_msas.py +++ b/sdmetrics/column_pairs/statistical/statistic_msas.py @@ -2,11 +2,12 @@ import pandas as pd +from sdmetrics.column_pairs.base import ColumnPairsMetric from sdmetrics.goal import Goal from sdmetrics.single_column.statistical.kscomplement import KSComplement -class StatisticMSAS: +class StatisticMSAS(ColumnPairsMetric): """Statistic Multi-Sequence Aggregate Similarity (MSAS) metric. Attributes: diff --git a/sdmetrics/single_column/statistical/sequence_length_similarity.py b/sdmetrics/single_column/statistical/sequence_length_similarity.py index 105f159b..f9fe1ae1 100644 --- a/sdmetrics/single_column/statistical/sequence_length_similarity.py +++ b/sdmetrics/single_column/statistical/sequence_length_similarity.py @@ -3,10 +3,11 @@ import pandas as pd from sdmetrics.goal import Goal +from sdmetrics.single_column.base import SingleColumnMetric from sdmetrics.single_column.statistical.kscomplement import KSComplement -class SequenceLengthSimilarity: +class SequenceLengthSimilarity(SingleColumnMetric): """Sequence Length Similarity metric. Attributes: diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index 647b6569..263a26c3 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -94,9 +94,10 @@ def test_compute_with_log_warning(self): 'There are 3 non-positive values in your data, which cannot be used with log. ' "Consider changing 'apply_log' to False for a better result." ) + assert len(warning_info) == 1 assert str(warning_info[0].message) == expected_message - assert score == 0 + assert score == 0.5 def test_compute_with_log_datetime(self): """Test it crashes for logs of datetime values.""" @@ -211,7 +212,9 @@ def test_compute_warning(self): synthetic_values = pd.Series([1, 10, 3, 7, 5, 1]) # Run and Assert - warn_msg = "n_rows_diff '10' is greater than the size of 2 sequence keys in real_data." + warn_msg = ( + "n_rows_diff '10' is greater or equal to the size of 2 sequence keys in real_data." + ) with pytest.warns(UserWarning, match=warn_msg): score = InterRowMSAS.compute( real_data=(real_keys, real_values), From 232c81c4b84abe29b259272c44125304d43b1fc9 Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 2 Dec 2024 12:23:13 -0800 Subject: [PATCH 2/2] Add tests --- .../statistical/test_inter_row_msas.py | 35 +++++++++++++++++++ .../statistical/test_statistic_msas.py | 18 ++++++++++ .../test_sequence_length_similarity.py | 14 ++++++++ 3 files changed, 67 insertions(+) diff --git a/tests/unit/column_pairs/statistical/test_inter_row_msas.py b/tests/unit/column_pairs/statistical/test_inter_row_msas.py index 263a26c3..0a0cfd88 100644 --- a/tests/unit/column_pairs/statistical/test_inter_row_msas.py +++ b/tests/unit/column_pairs/statistical/test_inter_row_msas.py @@ -1,5 +1,6 @@ from datetime import datetime +import numpy as np import pandas as pd import pytest @@ -7,6 +8,24 @@ class TestInterRowMSAS: + def test_compute_breakdown(self): + """Test `compute_breakdown` works.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 2, 3, 4, 5, 6]) + synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4']) + synthetic_values = pd.Series([1, 10, 3, 7, 5, 1]) + + metric = InterRowMSAS() + + # Run + result = metric.compute_breakdown( + real_data=(real_keys, real_values), synthetic_data=(synthetic_keys, synthetic_values) + ) + + # Assert + assert result == {'score': 0.5} + def test_compute(self): """Test it runs.""" # Setup @@ -23,6 +42,22 @@ def test_compute(self): # Assert assert score == 0.5 + def test_compute_nans(self): + """Test it runs with nans.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 2, np.nan, 4, 5, 8]) + synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4']) + synthetic_values = pd.Series([1, 10, 4, 7, np.nan, np.nan]) + + # Run + score = InterRowMSAS.compute( + real_data=(real_keys, real_values), synthetic_data=(synthetic_keys, synthetic_values) + ) + + # Assert + assert score == 0.5 + def test_compute_identical_sequences(self): """Test it returns 1 when real and synthetic data are identical.""" # Setup diff --git a/tests/unit/column_pairs/statistical/test_statistic_msas.py b/tests/unit/column_pairs/statistical/test_statistic_msas.py index 52338844..c2d71a78 100644 --- a/tests/unit/column_pairs/statistical/test_statistic_msas.py +++ b/tests/unit/column_pairs/statistical/test_statistic_msas.py @@ -7,6 +7,24 @@ class TestStatisticMSAS: + def test_compute_breakdown(self): + """Test `compute_breakdown` works.""" + # Setup + real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2']) + real_values = pd.Series([1, 2, 3, 4, 5, 6]) + synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4']) + synthetic_values = pd.Series([1, 10, 3, 7, 5, 1]) + + metric = StatisticMSAS() + + # Run + result = metric.compute_breakdown( + real_data=(real_keys, real_values), synthetic_data=(synthetic_keys, synthetic_values) + ) + + # Assert + assert result == {'score': 0.5} + def test_compute_identical_sequences(self, recwarn): """Test it returns 1 when real and synthetic data are identical.""" # Setup diff --git a/tests/unit/single_column/statistical/test_sequence_length_similarity.py b/tests/unit/single_column/statistical/test_sequence_length_similarity.py index 4e27ab98..52b661e9 100644 --- a/tests/unit/single_column/statistical/test_sequence_length_similarity.py +++ b/tests/unit/single_column/statistical/test_sequence_length_similarity.py @@ -4,6 +4,20 @@ class TestSequenceLengthSimilarity: + def test_compute_breakdown(self): + """Test `compute_breakdown` works.""" + # Setup + real_data = pd.Series([1, 1, 2, 2, 2]) + synthetic_data = pd.Series([3, 4, 5, 6, 6]) + + metric = SequenceLengthSimilarity() + + # Run + result = metric.compute_breakdown(real_data, synthetic_data) + + # Assert + assert result == {'score': 0.25} + def test_compute(self): """Test it runs.""" # Setup