diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e3f506d0..6f8f772b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -32,16 +32,10 @@ jobs: run: | brew install libomp echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV - - if: matrix.python-version != 3.13 - name: Install dependencies + - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[pomegranate,xgboost,test] - - if: matrix.python-version == 3.13 - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install invoke .[xgboost,test] + python -m pip install invoke .[test] - name: Run integration tests run: invoke integration diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index c65e9e6d..8a666061 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -35,6 +35,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[test,xgboost] + python -m pip install invoke .[test] - name: Test with minimum versions run: invoke minimum diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index b043f6d1..136474f4 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -35,7 +35,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[test,xgboost] + python -m pip install invoke .[test] - name: Run unit tests run: invoke unit diff --git a/pyproject.toml b/pyproject.toml index b1696da8..ed47ede8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,8 @@ license = { text = 'MIT license' } requires-python = ">=3.8,<3.14" readme = 'README.md' dependencies = [ - "numpy>=1.21.0;python_version<'3.10'", - "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'", + "numpy>=1.22.2;python_version<'3.10'", + 
"numpy>=1.24.0;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'", "numpy>=2.1.0;python_version>='3.13'", "pandas>=1.4.0;python_version<'3.11'", @@ -60,10 +60,10 @@ torch = [ "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", ] -pomegranate = ['pomegranate>=0.15,<1.0'] +pomegranate = ['pomegranate>=1.1.2,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ - 'sdmetrics[torch]', + 'sdmetrics[pomegranate,torch,xgboost]', 'pytest>=6.2.5,<7', 'pytest-cov>=2.6.0,<3', 'pytest-rerunfailures>=10.3,<15', @@ -74,7 +74,7 @@ test = [ 'pytest-runner>=2.11.1', ] dev = [ - 'sdmetrics[test, xgboost, torch]', + 'sdmetrics[test, torch]', # general 'build>=1.0.0,<2', diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 069127a7..b5d69f52 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -1,9 +1,10 @@ """BayesianNetwork based metrics for single table.""" -import json import logging import numpy as np +import pandas as pd +import torch from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric @@ -17,11 +18,10 @@ class BNLikelihoodBase(SingleTableMetric): @classmethod def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): try: - from pomegranate import BayesianNetwork + from pomegranate.bayesian_network import BayesianNetwork except ImportError: raise ImportError( - 'Please install pomegranate with `pip install sdmetrics[pomegranate]`. ' - 'Python 3.13 is not supported.' + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' 
) real_data, synthetic_data, metadata = cls._validate_inputs( @@ -34,19 +34,25 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): return np.full(len(real_data), np.nan) LOGGER.debug('Fitting the BayesianNetwork to the real data') - if structure: - if isinstance(structure, dict): - structure = BayesianNetwork.from_json(json.dumps(structure)).structure - - bn = BayesianNetwork.from_structure(real_data[fields].to_numpy(), structure) - else: - bn = BayesianNetwork.from_samples(real_data[fields].to_numpy(), algorithm='chow-liu') - + bn = BayesianNetwork(structure=structure if structure else None, algorithm='chow-liu') + category_to_integer = { + column: { + category: i + for i, category in enumerate( + pd.concat([real_data[column], synthetic_data[column]]).unique() + ) + } + for column in fields + } + real_data[fields] = real_data[fields].replace(category_to_integer).astype('int64') + synthetic_data[fields] = synthetic_data[fields].replace(category_to_integer).astype('int64') + + bn.fit(torch.tensor(real_data[fields].to_numpy())) LOGGER.debug('Evaluating likelihood of the synthetic data') probabilities = [] for _, row in synthetic_data[fields].iterrows(): try: - probabilities.append(bn.probability([row.to_numpy()])) + probabilities.append(bn.probability([row.to_numpy()]).item()) except ValueError: probabilities.append(0) diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index fc15bc58..b6b46b59 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -1,43 +1,97 @@ -import sys -from unittest.mock import Mock +import re +from unittest.mock import Mock, patch +import numpy as np +import pandas as pd import pytest from sdmetrics.single_table import BNLikelihood, BNLogLikelihood @pytest.fixture -def bad_pomegranate(): - old_pomegranate = getattr(sys.modules, 'pomegranate', None) - sys.modules['pomegranate'] = pytest - 
yield - if old_pomegranate is not None: - sys.modules['pomegranate'] = old_pomegranate - else: - del sys.modules['pomegranate'] +def real_data(): + return pd.DataFrame({ + 'a': ['a', 'b', 'a', 'b', 'a', 'b'], + 'b': ['c', 'd', 'c', 'd', 'c', 'd'], + 'c': [True, False, True, False, True, False], + 'd': [1, 2, 3, 4, 5, 6], + 'e': [10, 2, 3, 4, 5, 6], + }) + + +@pytest.fixture +def synthetic_data(): + return pd.DataFrame({ + 'a': ['a', 'b', 'b', 'b', 'a', 'b'], + 'b': ['d', 'd', 'c', 'd', 'c', 'd'], + 'c': [False, False, True, False, True, False], + 'd': [4, 2, 3, 4, 5, 6], + 'e': [12, 2, 3, 4, 5, 6], + }) + + +@pytest.fixture +def metadata(): + return { + 'columns': { + 'a': {'sdtype': 'categorical'}, + 'b': {'sdtype': 'categorical'}, + 'c': {'sdtype': 'boolean'}, + 'd': {'sdtype': 'categorical'}, + 'e': {'sdtype': 'numerical'}, + } + } class TestBNLikelihood: - def test_compute(self, bad_pomegranate): - """Test that an ``ImportError`` is raised.""" + @patch.dict('sys.modules', {'pomegranate.bayesian_network': None}) + def test_compute_error(self): + """Test that an `ImportError` is raised.""" # Setup metric = BNLikelihood() - # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 + # Run and Assert + expected_message = re.escape( + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' 
+ ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) + def test_compute(self, real_data, synthetic_data, metadata): + """Test the ``compute`` method.""" + # Setup + np.random.seed(42) + metric = BNLikelihood() + + # Run + result = metric.compute(real_data, synthetic_data, metadata) + + # Assert + assert np.isclose(result, 0.1111111044883728, atol=1e-5) + class TestBNLogLikelihood: - def test_compute(self, bad_pomegranate): - """Test that an ``ImportError`` is raised.""" + @patch.dict('sys.modules', {'pomegranate.bayesian_network': None}) + def test_compute_error(self): + """Test that an `ImportError` is raised.""" # Setup metric = BNLogLikelihood() - # Act and Assert - expected_message = expected_message = ( - r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 + # Run and Assert + expected_message = re.escape( + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' ) with pytest.raises(ImportError, match=expected_message): metric.compute(Mock(), Mock()) + + def test_compute(self, real_data, synthetic_data, metadata): + """Test the ``compute`` method.""" + # Setup + np.random.seed(42) + metric = BNLogLikelihood() + + # Run + result = metric.compute(real_data, synthetic_data, metadata) + + # Assert + assert np.isclose(result, -7.334733486175537, atol=1e-5)