Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdmetrics/timeseries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy
from sdmetrics.timeseries.inter_row_msas import InterRowMSAS
from sdmetrics.timeseries.sequence_length_similarity import SequenceLengthSimilarity
from sdmetrics.timeseries.statistic_msas import StatisticMSAS

__all__ = [
'base',
Expand All @@ -20,4 +21,5 @@
'LSTMClassifierEfficacy',
'InterRowMSAS',
'SequenceLengthSimilarity',
'StatisticMSAS',
]
96 changes: 96 additions & 0 deletions sdmetrics/timeseries/statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""StatisticMSAS module."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
from sdmetrics.single_column.statistical.kscomplement import KSComplement


class StatisticMSAS:
"""Statistic Multi-Sequence Aggregate Similarity (MSAS) metric.

Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'Statistic Multi-Sequence Aggregate Similarity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0

@staticmethod
def compute(real_data, synthetic_data, statistic='mean'):
"""Compute this metric.

This metric compares the distribution of a given statistic across sequences
in the real data vs. the synthetic data.

It works as follows:
- Calculate the specified statistic for each sequence in the real data
- Form a distribution D_r from these statistics
- Do the same for the synthetic data to form a new distribution D_s
- Apply the KSComplement metric to compare the similarities of (D_r, D_s)
- Return this score

Args:
real_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the real data and the second represents a continuous column of data.
synthetic_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the synthetic data and the second represents a continuous column of data.
statistic (str):
A string representing the statistic function to use when computing MSAS.

Available options are:
- 'mean': The arithmetic mean of the sequence
- 'median': The median value of the sequence
- 'std': The standard deviation of the sequence
- 'min': The minimum value in the sequence
- 'max': The maximum value in the sequence

Returns:
float:
The similarity score between the real and synthetic data distributions.
"""
statistic_functions = {
'mean': np.mean,
'median': np.median,
'std': np.std,
'min': np.min,
'max': np.max,
}
if statistic not in statistic_functions:
raise ValueError(
f'Invalid statistic: {statistic}.'
f' Choose from [{", ".join(statistic_functions.keys())}].'
)

for data in [real_data, synthetic_data]:
if (
not isinstance(data, tuple)
or len(data) != 2
or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
):
raise ValueError('The data must be a tuple of two pandas series.')

real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data
stat_func = statistic_functions[statistic]

def calculate_statistics(keys, values):
df = pd.DataFrame({'keys': keys, 'values': values})
return df.groupby('keys')['values'].agg(stat_func)

real_stats = calculate_statistics(real_keys, real_values)
synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)

return KSComplement.compute(real_stats, synthetic_stats)
125 changes: 125 additions & 0 deletions tests/unit/timeseries/test_statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import re

import pandas as pd
import pytest

from sdmetrics.timeseries import StatisticMSAS


class TestStatisticMSAS:
def test_compute_identical_sequences(self):
"""Test it returns 1 when real and synthetic data are identical."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 1

def test_compute_different_sequences(self):
"""Test it for distinct distributions."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([10, 20, 30, 40, 50, 60])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 0

def test_compute_with_single_sequence(self):
"""Test it with a single sequence."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == 1

def test_compute_with_different_sequence_lengths(self):
"""Test it with different sequence lengths."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5])
synthetic_keys = pd.Series(['id2', 'id2', 'id3', 'id4', 'id5'])
synthetic_values = pd.Series([1, 2, 3, 4, 5])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == 0.75

def test_compute_with_invalid_statistic(self):
"""Test it raises ValueError for invalid statistic."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run and Assert
err_msg = re.escape(
'Invalid statistic: invalid. Choose from [mean, median, std, min, max].'
)
with pytest.raises(ValueError, match=err_msg):
StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='invalid',
)

def test_compute_invalid_real_data(self):
"""Test that it raises ValueError when real_data is invalid."""
# Setup
real_data = [[1, 2, 3], [4, 5, 6]] # Not a tuple of pandas Series
synthetic_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3, 4])

# Run and Assert
with pytest.raises(ValueError, match='The data must be a tuple of two pandas series.'):
StatisticMSAS.compute(
real_data=real_data,
synthetic_data=(synthetic_keys, synthetic_values),
)

def test_compute_invalid_synthetic_data(self):
"""Test that it raises ValueError when synthetic_data is invalid."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4])
synthetic_data = [[1, 2, 3], [4, 5, 6]] # Not a tuple of pandas Series

# Run and Assert
with pytest.raises(ValueError, match='The data must be a tuple of two pandas series.'):
StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=synthetic_data,
)
Loading