Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,10 @@ jobs:
run: |
brew install libomp
echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV
- if: matrix.python-version != 3.13
name: Install dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install invoke .[pomegranate,xgboost,test]
- if: matrix.python-version == 3.13
name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install invoke .[xgboost,test]
python -m pip install invoke .[test]
- name: Run integration tests
run: invoke integration

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/minimum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install invoke .[test,xgboost]
python -m pip install invoke .[test]
- name: Test with minimum versions
run: invoke minimum
2 changes: 1 addition & 1 deletion .github/workflows/unit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install invoke .[test,xgboost]
python -m pip install invoke .[test]
- name: Run unit tests
run: invoke unit

Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ license = { text = 'MIT license' }
requires-python = ">=3.8,<3.14"
readme = 'README.md'
dependencies = [
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.22.2;python_version<'3.10'",
"numpy>=1.24.0;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'",
"numpy>=2.1.0;python_version>='3.13'",
"pandas>=1.4.0;python_version<'3.11'",
Expand Down Expand Up @@ -60,10 +60,10 @@ torch = [
"torch>=2.2.0;python_version>='3.12' and python_version<'3.13'",
"torch>=2.6.0;python_version>='3.13'",
]
pomegranate = ['pomegranate>=0.15,<1.0']
pomegranate = ['pomegranate>=1.1.2,<2.0']
xgboost = ['xgboost>=2.1.3']
test = [
'sdmetrics[torch]',
'sdmetrics[pomegranate,torch,xgboost]',
'pytest>=6.2.5,<7',
'pytest-cov>=2.6.0,<3',
'pytest-rerunfailures>=10.3,<15',
Expand All @@ -74,7 +74,7 @@ test = [
'pytest-runner>=2.11.1',
]
dev = [
'sdmetrics[test, xgboost, torch]',
'sdmetrics[test, torch]',

# general
'build>=1.0.0,<2',
Expand Down
32 changes: 19 additions & 13 deletions sdmetrics/single_table/bayesian_network.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""BayesianNetwork based metrics for single table."""

import json
import logging

import numpy as np
import pandas as pd
import torch

from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
Expand All @@ -17,11 +18,10 @@ class BNLikelihoodBase(SingleTableMetric):
@classmethod
def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
try:
from pomegranate import BayesianNetwork
from pomegranate.bayesian_network import BayesianNetwork
except ImportError:
raise ImportError(
'Please install pomegranate with `pip install sdmetrics[pomegranate]`. '
'Python 3.13 is not supported.'
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
)

real_data, synthetic_data, metadata = cls._validate_inputs(
Expand All @@ -34,19 +34,25 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None):
return np.full(len(real_data), np.nan)

LOGGER.debug('Fitting the BayesianNetwork to the real data')
if structure:
if isinstance(structure, dict):
structure = BayesianNetwork.from_json(json.dumps(structure)).structure

bn = BayesianNetwork.from_structure(real_data[fields].to_numpy(), structure)
else:
bn = BayesianNetwork.from_samples(real_data[fields].to_numpy(), algorithm='chow-liu')

bn = BayesianNetwork(structure=structure if structure else None, algorithm='chow-liu')
category_to_integer = {
column: {
category: i
for i, category in enumerate(
pd.concat([real_data[column], synthetic_data[column]]).unique()
)
}
for column in fields
}
real_data[fields] = real_data[fields].replace(category_to_integer).astype('int64')
synthetic_data[fields] = synthetic_data[fields].replace(category_to_integer).astype('int64')

bn.fit(torch.tensor(real_data[fields].to_numpy()))
LOGGER.debug('Evaluating likelihood of the synthetic data')
probabilities = []
for _, row in synthetic_data[fields].iterrows():
try:
probabilities.append(bn.probability([row.to_numpy()]))
probabilities.append(bn.probability([row.to_numpy()]).item())
except ValueError:
probabilities.append(0)

Expand Down
92 changes: 73 additions & 19 deletions tests/unit/single_table/test_bayesian_network.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,97 @@
import sys
from unittest.mock import Mock
import re
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

from sdmetrics.single_table import BNLikelihood, BNLogLikelihood


@pytest.fixture
def bad_pomegranate():
old_pomegranate = getattr(sys.modules, 'pomegranate', None)
sys.modules['pomegranate'] = pytest
yield
if old_pomegranate is not None:
sys.modules['pomegranate'] = old_pomegranate
else:
del sys.modules['pomegranate']
def real_data():
"""Return the six-row "real" table passed to the ``compute`` tests.

Columns ``a`` and ``b`` hold strings, ``c`` holds booleans, and ``d`` and
``e`` hold integers.
"""
return pd.DataFrame({
'a': ['a', 'b', 'a', 'b', 'a', 'b'],
'b': ['c', 'd', 'c', 'd', 'c', 'd'],
'c': [True, False, True, False, True, False],
'd': [1, 2, 3, 4, 5, 6],
'e': [10, 2, 3, 4, 5, 6],
})


@pytest.fixture
def synthetic_data():
"""Return the six-row "synthetic" table passed to the ``compute`` tests.

It has the same columns (``a`` to ``e``) as the ``real_data`` fixture, but
some of its values differ from the real table.
"""
return pd.DataFrame({
'a': ['a', 'b', 'b', 'b', 'a', 'b'],
'b': ['d', 'd', 'c', 'd', 'c', 'd'],
'c': [False, False, True, False, True, False],
'd': [4, 2, 3, 4, 5, 6],
'e': [12, 2, 3, 4, 5, 6],
})


@pytest.fixture
def metadata():
"""Return the metadata dict describing the fixture tables' columns.

Columns ``a``, ``b`` and ``d`` are declared categorical, ``c`` boolean and
``e`` numerical.
"""
return {
'columns': {
'a': {'sdtype': 'categorical'},
'b': {'sdtype': 'categorical'},
'c': {'sdtype': 'boolean'},
# ``d`` holds integers but is declared categorical here.
'd': {'sdtype': 'categorical'},
'e': {'sdtype': 'numerical'},
}
}


class TestBNLikelihood:
def test_compute(self, bad_pomegranate):
"""Test that an ``ImportError`` is raised."""
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
def test_compute_error(self):
"""Test that an `ImportError` is raised."""
# Setup
metric = BNLikelihood()

# Act and Assert
expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501
# Run and Assert
expected_message = re.escape(
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
)
with pytest.raises(ImportError, match=expected_message):
metric.compute(Mock(), Mock())

def test_compute(self, real_data, synthetic_data, metadata):
"""Test the ``compute`` method.

Computes ``BNLikelihood`` on the fixture tables and checks the score
against a pinned reference value.
"""
# Setup
# Seed NumPy's global RNG (presumably for reproducibility -- confirm the
# metric actually draws from it).
np.random.seed(42)
metric = BNLikelihood()

# Run
result = metric.compute(real_data, synthetic_data, metadata)

# Assert
# Pinned reference score, compared with an absolute tolerance of 1e-5.
assert np.isclose(result, 0.1111111044883728, atol=1e-5)
Comment on lines +66 to +70
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests ensure that the metric's computed values remain consistent across the pomegranate upgrade.



class TestBNLogLikelihood:
def test_compute(self, bad_pomegranate):
"""Test that an ``ImportError`` is raised."""
@patch.dict('sys.modules', {'pomegranate.bayesian_network': None})
def test_compute_error(self):
"""Test that an `ImportError` is raised."""
# Setup
metric = BNLogLikelihood()

# Act and Assert
expected_message = expected_message = (
r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501
# Run and Assert
expected_message = re.escape(
'Please install pomegranate with `pip install sdmetrics[pomegranate]`.'
)
with pytest.raises(ImportError, match=expected_message):
metric.compute(Mock(), Mock())

def test_compute(self, real_data, synthetic_data, metadata):
"""Test the ``compute`` method.

Computes ``BNLogLikelihood`` on the fixture tables and checks the score
against a pinned reference value.
"""
# Setup
# Seed NumPy's global RNG (presumably for reproducibility -- confirm the
# metric actually draws from it).
np.random.seed(42)
metric = BNLogLikelihood()

# Run
result = metric.compute(real_data, synthetic_data, metadata)

# Assert
# Pinned reference score, compared with an absolute tolerance of 1e-5.
assert np.isclose(result, -7.334733486175537, atol=1e-5)
Loading