Skip to content

Commit 6bec051

Browse files
authored
Add metric for general MSAS statistics (#649)
1 parent dd93b1a commit 6bec051

File tree

3 files changed

+223
-0
lines changed

3 files changed

+223
-0
lines changed

sdmetrics/timeseries/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy
88
from sdmetrics.timeseries.inter_row_msas import InterRowMSAS
99
from sdmetrics.timeseries.sequence_length_similarity import SequenceLengthSimilarity
10+
from sdmetrics.timeseries.statistic_msas import StatisticMSAS
1011

1112
__all__ = [
1213
'base',
@@ -20,4 +21,5 @@
2021
'LSTMClassifierEfficacy',
2122
'InterRowMSAS',
2223
'SequenceLengthSimilarity',
24+
'StatisticMSAS',
2325
]
sdmetrics/timeseries/statistic_msas.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""StatisticMSAS module."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
from sdmetrics.goal import Goal
7+
from sdmetrics.single_column.statistical.kscomplement import KSComplement
8+
9+
10+
class StatisticMSAS:
    """Statistic Multi-Sequence Aggregate Similarity (MSAS) metric.

    Attributes:
        name (str):
            Name to use when reports about this metric are printed.
        goal (sdmetrics.goal.Goal):
            The goal of this metric.
        min_value (Union[float, tuple[float]]):
            Minimum value or values that this metric can take.
        max_value (Union[float, tuple[float]]):
            Maximum value or values that this metric can take.
    """

    name = 'Statistic Multi-Sequence Aggregate Similarity'
    goal = Goal.MAXIMIZE
    min_value = 0.0
    max_value = 1.0

    @staticmethod
    def compute(real_data, synthetic_data, statistic='mean'):
        """Compute this metric.

        This metric compares the distribution of a given statistic across sequences
        in the real data vs. the synthetic data.

        It works as follows:
            - Calculate the specified statistic for each sequence in the real data
            - Form a distribution D_r from these statistics
            - Do the same for the synthetic data to form a new distribution D_s
            - Apply the KSComplement metric to compare the similarities of (D_r, D_s)
            - Return this score

        Args:
            real_data (tuple[pd.Series, pd.Series]):
                A tuple of 2 pandas.Series objects. The first represents the sequence key
                of the real data and the second represents a continuous column of data.
            synthetic_data (tuple[pd.Series, pd.Series]):
                A tuple of 2 pandas.Series objects. The first represents the sequence key
                of the synthetic data and the second represents a continuous column of data.
            statistic (str):
                A string representing the statistic function to use when computing MSAS.

                Available options are:
                    - 'mean': The arithmetic mean of the sequence
                    - 'median': The median value of the sequence
                    - 'std': The standard deviation of the sequence, as computed by the
                      pandas group-by ``std`` aggregation (sample standard deviation,
                      ``ddof=1`` by default)
                    - 'min': The minimum value in the sequence
                    - 'max': The maximum value in the sequence

        Returns:
            float:
                The similarity score between the real and synthetic data distributions.

        Raises:
            ValueError:
                If ``statistic`` is not one of the available options, or if either
                ``real_data`` or ``synthetic_data`` is not a tuple of two pandas.Series.
        """
        # The statistics are handed to ``groupby(...).agg`` by *name* rather than as
        # NumPy callables. pandas has historically swapped callables such as
        # ``np.std`` for its own group-by implementation behind the scenes, and
        # recent releases deprecate that aliasing; using the string names keeps the
        # same results without relying on it.
        valid_statistics = ('mean', 'median', 'std', 'min', 'max')
        if statistic not in valid_statistics:
            raise ValueError(
                f'Invalid statistic: {statistic}.'
                f' Choose from [{", ".join(valid_statistics)}].'
            )

        for data in (real_data, synthetic_data):
            is_series_pair = (
                isinstance(data, tuple)
                and len(data) == 2
                and isinstance(data[0], pd.Series)
                and isinstance(data[1], pd.Series)
            )
            if not is_series_pair:
                raise ValueError('The data must be a tuple of two pandas series.')

        real_keys, real_values = real_data
        synthetic_keys, synthetic_values = synthetic_data

        def calculate_statistics(keys, values):
            # One row per observation; grouping by the sequence key yields one
            # aggregated value per sequence.
            df = pd.DataFrame({'keys': keys, 'values': values})
            return df.groupby('keys')['values'].agg(statistic)

        real_stats = calculate_statistics(real_keys, real_values)
        synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)

        return KSComplement.compute(real_stats, synthetic_stats)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import re
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from sdmetrics.timeseries import StatisticMSAS
7+
8+
9+
class TestStatisticMSAS:
    """Unit tests for ``StatisticMSAS.compute``."""

    @pytest.mark.parametrize('statistic', ['mean', 'median', 'std', 'min', 'max'])
    def test_compute_identical_sequences(self, statistic):
        """Test it returns 1 when real and synthetic data are identical."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
        real_values = pd.Series([1, 2, 3, 4, 5, 6])
        synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
        synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])

        # Run
        score = StatisticMSAS.compute(
            real_data=(real_keys, real_values),
            synthetic_data=(synthetic_keys, synthetic_values),
            statistic=statistic,
        )

        # Assert
        assert score == 1

    @pytest.mark.parametrize('statistic', ['mean', 'median', 'std', 'min', 'max'])
    def test_compute_different_sequences(self, statistic):
        """Test it for distinct distributions."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
        real_values = pd.Series([1, 2, 3, 4, 5, 6])
        synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
        synthetic_values = pd.Series([10, 20, 30, 40, 50, 60])

        # Run
        score = StatisticMSAS.compute(
            real_data=(real_keys, real_values),
            synthetic_data=(synthetic_keys, synthetic_values),
            statistic=statistic,
        )

        # Assert
        assert score == 0

    def test_compute_with_single_sequence(self):
        """Test it with a single sequence."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id1'])
        real_values = pd.Series([1, 2, 3])
        synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
        synthetic_values = pd.Series([1, 2, 3])

        # Run
        score = StatisticMSAS.compute(
            real_data=(real_keys, real_values),
            synthetic_data=(synthetic_keys, synthetic_values),
            statistic='mean',
        )

        # Assert
        assert score == 1

    def test_compute_with_different_sequence_lengths(self):
        """Test it with different sequence lengths."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2'])
        real_values = pd.Series([1, 2, 3, 4, 5])
        synthetic_keys = pd.Series(['id2', 'id2', 'id3', 'id4', 'id5'])
        synthetic_values = pd.Series([1, 2, 3, 4, 5])

        # Run
        score = StatisticMSAS.compute(
            real_data=(real_keys, real_values),
            synthetic_data=(synthetic_keys, synthetic_values),
            statistic='mean',
        )

        # Assert
        assert score == 0.75

    def test_compute_with_invalid_statistic(self):
        """Test it raises ValueError for invalid statistic."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id1'])
        real_values = pd.Series([1, 2, 3])
        synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
        synthetic_values = pd.Series([1, 2, 3])

        # Run and Assert
        err_msg = re.escape(
            'Invalid statistic: invalid. Choose from [mean, median, std, min, max].'
        )
        with pytest.raises(ValueError, match=err_msg):
            StatisticMSAS.compute(
                real_data=(real_keys, real_values),
                synthetic_data=(synthetic_keys, synthetic_values),
                statistic='invalid',
            )

    def test_compute_invalid_real_data(self):
        """Test that it raises ValueError when real_data is invalid."""
        # Setup
        real_data = [[1, 2, 3], [4, 5, 6]]  # Not a tuple of pandas Series
        synthetic_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
        synthetic_values = pd.Series([1, 2, 3, 4])

        # Run and Assert
        # ``match`` is a regular expression, so the message is escaped to make the
        # trailing '.' match literally.
        err_msg = re.escape('The data must be a tuple of two pandas series.')
        with pytest.raises(ValueError, match=err_msg):
            StatisticMSAS.compute(
                real_data=real_data,
                synthetic_data=(synthetic_keys, synthetic_values),
            )

    def test_compute_invalid_synthetic_data(self):
        """Test that it raises ValueError when synthetic_data is invalid."""
        # Setup
        real_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
        real_values = pd.Series([1, 2, 3, 4])
        synthetic_data = [[1, 2, 3], [4, 5, 6]]  # Not a tuple of pandas Series

        # Run and Assert
        # ``match`` is a regular expression, so the message is escaped to make the
        # trailing '.' match literally.
        err_msg = re.escape('The data must be a tuple of two pandas series.')
        with pytest.raises(ValueError, match=err_msg):
            StatisticMSAS.compute(
                real_data=(real_keys, real_values),
                synthetic_data=synthetic_data,
            )

0 commit comments

Comments
 (0)