From e91ed28d3bbddd7e7aecef1ebddb48ead91f2b6d Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 10:41:42 -0400 Subject: [PATCH 01/12] def --- pyproject.toml | 4 +-- sdmetrics/single_table/bayesian_network.py | 35 ++++++++++++---------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1696da8..9a767ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ torch = [ "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", ] -pomegranate = ['pomegranate>=0.15,<1.0'] +pomegranate = ['pomegranate>=1.0,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ 'sdmetrics[torch]', @@ -74,7 +74,7 @@ test = [ 'pytest-runner>=2.11.1', ] dev = [ - 'sdmetrics[test, xgboost, torch]', + 'sdmetrics[test, xgboost, torch, pomegranate]', # general 'build>=1.0.0,<2', diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 069127a7..0ff078f1 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -1,9 +1,11 @@ """BayesianNetwork based metrics for single table.""" -import json import logging import numpy as np +import pandas as pd +import torch +from pomegranate.bayesian_network import BayesianNetwork from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric @@ -16,14 +18,6 @@ class BNLikelihoodBase(SingleTableMetric): @classmethod def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): - try: - from pomegranate import BayesianNetwork - except ImportError: - raise ImportError( - 'Please install pomegranate with `pip install sdmetrics[pomegranate]`. ' - 'Python 3.13 is not supported.' - ) - real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata ) @@ -34,14 +28,23 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): return np.full(len(real_data), np.nan) LOGGER.debug('Fitting the BayesianNetwork to the real data') - if structure: - if isinstance(structure, dict): - structure = BayesianNetwork.from_json(json.dumps(structure)).structure - - bn = BayesianNetwork.from_structure(real_data[fields].to_numpy(), structure) - else: - bn = BayesianNetwork.from_samples(real_data[fields].to_numpy(), algorithm='chow-liu') + bn = BayesianNetwork(structure=structure if structure else None, algorithm='chow-liu') + category_to_integer = { + column: { + category: i + for i, category in enumerate( + pd.concat([real_data[column], synthetic_data[column]]).unique() + ) + } + for column in fields + } + for column in fields: + real_data[column] = real_data[column].map(category_to_integer[column]).astype(int) + synthetic_data[column] = ( + synthetic_data[column].map(category_to_integer[column]).astype(int) + ) + bn.fit(torch.tensor(real_data[fields].to_numpy())) LOGGER.debug('Evaluating likelihood of the synthetic data') probabilities = [] for _, row in synthetic_data[fields].iterrows(): From fbdf5f63eb26583373cab7997490881988259a0f Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 10:44:22 -0400 Subject: [PATCH 02/12] tests --- .../single_table/test_bayesian_network.py | 75 ++++++++++++------- 1 file changed, 50 insertions(+), 25 deletions(-) diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index fc15bc58..d29b1f7f 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -1,43 +1,68 @@ -import sys -from unittest.mock import Mock - +import numpy as np +import pandas as pd import pytest from sdmetrics.single_table import BNLikelihood, BNLogLikelihood @pytest.fixture -def bad_pomegranate(): - old_pomegranate = getattr(sys.modules, 'pomegranate', None) - sys.modules['pomegranate'] = pytest - yield - if old_pomegranate is not None: - sys.modules['pomegranate'] = old_pomegranate - else: - del sys.modules['pomegranate'] +def real_data(): + return pd.DataFrame({ + 'a': ['a', 'b', 'a', 'b', 'a', 'b'], + 'b': ['c', 'd', 'c', 'd', 'c', 'd'], + 'c': [True, False, True, False, True, False], + 'd': [1, 2, 3, 4, 5, 6], + 'e': [10, 2, 3, 4, 5, 6], + }) + + +@pytest.fixture +def synthetic_data(): + return pd.DataFrame({ + 'a': ['a', 'b', 'b', 'b', 'a', 'b'], + 'b': ['d', 'd', 'c', 'd', 'c', 'd'], + 'c': [False, False, True, False, True, False], + 'd': [4, 2, 3, 4, 5, 6], + 'e': [12, 2, 3, 4, 5, 6], + }) + + +@pytest.fixture +def metadata(): + return { + 'columns': { + 'a': {'sdtype': 'categorical'}, + 'b': {'sdtype': 'categorical'}, + 'c': {'sdtype': 'boolean'}, + 'd': {'sdtype': 'categorical'}, + 'e': {'sdtype': 'numerical'}, + } + } class TestBNLikelihood: - def test_compute(self, bad_pomegranate): - """Test that an ``ImportError`` is raised.""" + def test_compute(self, real_data, synthetic_data, metadata): + """Test the metric end to end.""" # Setup + np.random.seed(42) metric = BNLikelihood() - # Act and Assert - expected_message = r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 - with pytest.raises(ImportError, match=expected_message): - metric.compute(Mock(), Mock()) + # Run + result = metric.compute(real_data, synthetic_data, metadata) + + # Assert + assert result == 0.111111104 class TestBNLogLikelihood: - def test_compute(self, bad_pomegranate): - """Test that an ``ImportError`` is raised.""" + def test_compute(self, real_data, synthetic_data, metadata): + """Test the ``compute``method.""" # Setup + np.random.seed(42) metric = BNLogLikelihood() - # Act and Assert - expected_message = expected_message = ( - r'Please install pomegranate with `pip install sdmetrics\[pomegranate\]`\. Python 3\.13 is not supported\.' # noqa: E501 - ) - with pytest.raises(ImportError, match=expected_message): - metric.compute(Mock(), Mock()) + # Run + result = metric.compute(real_data, synthetic_data, metadata) + + # Assert + assert result == -7.3347335 From 4c00cf7526096844992bde77c0f84d2a0b047fe6 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 10:57:12 -0400 Subject: [PATCH 03/12] update workflows --- .github/workflows/integration.yml | 8 +------- .github/workflows/minimum.yml | 2 +- .github/workflows/unit.yml | 2 +- tests/unit/single_table/test_bayesian_network.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e3f506d0..4e7d2e5e 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -32,16 +32,10 @@ jobs: run: | brew install libomp echo 'export DYLD_LIBRARY_PATH=$(brew --prefix libomp)/lib:$DYLD_LIBRARY_PATH' >> $GITHUB_ENV - - if: matrix.python-version != 3.13 - name: Install dependencies + - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install invoke .[pomegranate,xgboost,test] - - if: matrix.python-version == 3.13 - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install invoke .[xgboost,test] - name: Run integration tests run: invoke integration diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index c65e9e6d..738135c1 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -35,6 +35,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[test,xgboost] + python -m pip install invoke .[pomegranate,xgboost,test] - name: Test with minimum versions run: invoke minimum diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index b043f6d1..91a14ab6 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -35,7 +35,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[test,xgboost] + python -m pip install invoke .[pomegranate,xgboost,test] - name: Run unit tests run: invoke unit diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index d29b1f7f..433b25e3 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -42,7 +42,7 @@ def metadata(): class TestBNLikelihood: def test_compute(self, real_data, synthetic_data, metadata): - """Test the metric end to end.""" + """Test the ``compute``method.""" # Setup np.random.seed(42) metric = BNLikelihood() From f829f7f795157a3914a6bff5d1b4aef8d21a447a Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 13:44:09 -0400 Subject: [PATCH 04/12] use torch.tensor --- sdmetrics/single_table/bayesian_network.py | 8 ++++---- tests/unit/single_table/test_bayesian_network.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 0ff078f1..3a40b272 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -49,9 +49,9 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): probabilities = [] for _, row in synthetic_data[fields].iterrows(): try: - probabilities.append(bn.probability([row.to_numpy()])) + probabilities.append(torch.tensor(bn.probability([row.to_numpy()]))) except ValueError: - probabilities.append(0) + probabilities.append(torch.tensor(0)) return np.asarray(probabilities) @@ -125,7 +125,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None): float: Mean of the probabilities returned by the Bayesian Network. """ - return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure)) + return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure)).item() class BNLogLikelihood(BNLikelihoodBase): @@ -199,7 +199,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None): """ likelihoods = cls._likelihoods(real_data, synthetic_data, metadata, structure) likelihoods[np.where(likelihoods == 0)] = 1e-8 - return np.mean(np.log(likelihoods)) + return np.mean(np.log(likelihoods)).item() @classmethod def normalize(cls, raw_score): diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 433b25e3..95a38e67 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -51,7 +51,7 @@ def test_compute(self, real_data, synthetic_data, metadata): result = metric.compute(real_data, synthetic_data, metadata) # Assert - assert result == 0.111111104 + assert result == 0.1111111044883728 class TestBNLogLikelihood: @@ -65,4 +65,4 @@ def test_compute(self, real_data, synthetic_data, metadata): result = metric.compute(real_data, synthetic_data, metadata) # Assert - assert result == -7.3347335 + assert result == -7.334733486175537 From 8f133a7be97900f4b7415922f682c07cb65b14f7 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 13:51:31 -0400 Subject: [PATCH 05/12] update pomegranate version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9a767ef2..33f82c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ torch = [ "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", ] -pomegranate = ['pomegranate>=1.0,<2.0'] +pomegranate = ['pomegranate>=1.1.1,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ 'sdmetrics[torch]', From b68060b992a9a85730497a100f0b65af0ad145fe Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 14:35:38 -0400 Subject: [PATCH 06/12] fix 3.8 workflows + minimum tests --- pyproject.toml | 2 +- sdmetrics/single_table/bayesian_network.py | 8 ++++---- tests/unit/single_table/test_bayesian_network.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 33f82c6f..c0dceea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ license = { text = 'MIT license' } requires-python = ">=3.8,<3.14" readme = 'README.md' dependencies = [ - "numpy>=1.21.0;python_version<'3.10'", + "numpy>=1.22.2;python_version<'3.10'", "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'", "numpy>=2.1.0;python_version>='3.13'", diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 3a40b272..2122a8f8 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -49,9 +49,9 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): probabilities = [] for _, row in synthetic_data[fields].iterrows(): try: - probabilities.append(torch.tensor(bn.probability([row.to_numpy()]))) + probabilities.append(bn.probability([row.to_numpy()]).item()) except ValueError: - probabilities.append(torch.tensor(0)) + probabilities.append(0) return np.asarray(probabilities) @@ -125,7 +125,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None): float: Mean of the probabilities returned by the Bayesian Network. """ - return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure)).item() + return np.mean(cls._likelihoods(real_data, synthetic_data, metadata, structure)) class BNLogLikelihood(BNLikelihoodBase): @@ -199,7 +199,7 @@ def compute(cls, real_data, synthetic_data, metadata=None, structure=None): """ likelihoods = cls._likelihoods(real_data, synthetic_data, metadata, structure) likelihoods[np.where(likelihoods == 0)] = 1e-8 - return np.mean(np.log(likelihoods)).item() + return np.mean(np.log(likelihoods)) @classmethod def normalize(cls, raw_score): diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 95a38e67..6a9f947a 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -51,7 +51,7 @@ def test_compute(self, real_data, synthetic_data, metadata): result = metric.compute(real_data, synthetic_data, metadata) # Assert - assert result == 0.1111111044883728 + assert np.isclose(result, 0.1111111044883728, atol=1e-5) class TestBNLogLikelihood: @@ -65,4 +65,4 @@ def test_compute(self, real_data, synthetic_data, metadata): result = metric.compute(real_data, synthetic_data, metadata) # Assert - assert result == -7.334733486175537 + assert np.isclose(result, -7.334733486175537, atol=1e-5) From cf7568e4f4492d0d61d007d3ac6350879d75f963 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Thu, 27 Feb 2025 14:55:33 -0400 Subject: [PATCH 07/12] fix minimum version 2 (python 3.10) --- pyproject.toml | 2 +- sdmetrics/single_table/bayesian_network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c0dceea1..1cff70d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ requires-python = ">=3.8,<3.14" readme = 'README.md' dependencies = [ "numpy>=1.22.2;python_version<'3.10'", - "numpy>=1.23.3;python_version>='3.10' and python_version<'3.12'", + "numpy>=1.24.0;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0;python_version>='3.12' and python_version<'3.13'", "numpy>=2.1.0;python_version>='3.13'", "pandas>=1.4.0;python_version<'3.11'", diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 2122a8f8..b3410a15 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -39,7 +39,7 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): for column in fields } for column in fields: - real_data[column] = real_data[column].map(category_to_integer[column]).astype(int) + real_data[column] = real_data[column].map(category_to_integer[column]).astype('int64') synthetic_data[column] = ( synthetic_data[column].map(category_to_integer[column]).astype(int) ) From af2257baa2931c8f1beb8b0f928b0141a013d530 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 28 Feb 2025 11:49:56 -0400 Subject: [PATCH 08/12] upgrade to 1.1.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1cff70d5..ad36495a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ torch = [ "torch>=2.2.0;python_version>='3.12' and python_version<'3.13'", "torch>=2.6.0;python_version>='3.13'", ] -pomegranate = ['pomegranate>=1.1.1,<2.0'] +pomegranate = ['pomegranate>=1.1.2,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ 'sdmetrics[torch]', From 208c92a703b37e143a63eb3fd5f2232848160f1a Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 28 Feb 2025 15:24:44 -0400 Subject: [PATCH 09/12] add import error handling --- sdmetrics/single_table/bayesian_network.py | 8 ++++- .../single_table/test_bayesian_network.py | 29 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index b3410a15..0727de82 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import torch -from pomegranate.bayesian_network import BayesianNetwork from sdmetrics.goal import Goal from sdmetrics.single_table.base import SingleTableMetric @@ -18,6 +17,13 @@ class BNLikelihoodBase(SingleTableMetric): @classmethod def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): + try: + from pomegranate.bayesian_network import BayesianNetwork + except ImportError: + raise ImportError( + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' + ) + real_data, synthetic_data, metadata = cls._validate_inputs( real_data, synthetic_data, metadata ) diff --git a/tests/unit/single_table/test_bayesian_network.py b/tests/unit/single_table/test_bayesian_network.py index 6a9f947a..b6b46b59 100644 --- a/tests/unit/single_table/test_bayesian_network.py +++ b/tests/unit/single_table/test_bayesian_network.py @@ -1,3 +1,6 @@ +import re +from unittest.mock import Mock, patch + import numpy as np import pandas as pd import pytest @@ -41,6 +44,19 @@ def metadata(): class TestBNLikelihood: + @patch.dict('sys.modules', {'pomegranate.bayesian_network': None}) + def test_compute_error(self): + """Test that an `ImportError` is raised.""" + # Setup + metric = BNLikelihood() + + # Run and Assert + expected_message = re.escape( + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' + ) + with pytest.raises(ImportError, match=expected_message): + metric.compute(Mock(), Mock()) + def test_compute(self, real_data, synthetic_data, metadata): """Test the ``compute``method.""" # Setup @@ -55,6 +71,19 @@ def test_compute(self, real_data, synthetic_data, metadata): class TestBNLogLikelihood: + @patch.dict('sys.modules', {'pomegranate.bayesian_network': None}) + def test_compute_error(self): + """Test that an `ImportError` is raised.""" + # Setup + metric = BNLogLikelihood() + + # Run and Assert + expected_message = re.escape( + 'Please install pomegranate with `pip install sdmetrics[pomegranate]`.' + ) + with pytest.raises(ImportError, match=expected_message): + metric.compute(Mock(), Mock()) + def test_compute(self, real_data, synthetic_data, metadata): """Test the ``compute``method.""" # Setup From 9d9dc59f27c26e7b458a2388fa53ccb94873398b Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 28 Feb 2025 15:25:03 -0400 Subject: [PATCH 10/12] update pyproject and workflows --- .github/workflows/integration.yml | 2 +- .github/workflows/minimum.yml | 2 +- .github/workflows/unit.yml | 2 +- pyproject.toml | 4 +++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 4e7d2e5e..6f8f772b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -35,7 +35,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[pomegranate,xgboost,test] + python -m pip install invoke .[test] - name: Run integration tests run: invoke integration diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 738135c1..8a666061 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -35,6 +35,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[pomegranate,xgboost,test] + python -m pip install invoke .[test] - name: Test with minimum versions run: invoke minimum diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index 91a14ab6..136474f4 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -35,7 +35,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install invoke .[pomegranate,xgboost,test] + python -m pip install invoke .[test] - name: Run unit tests run: invoke unit diff --git a/pyproject.toml b/pyproject.toml index ad36495a..f6da43f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,8 @@ torch = [ pomegranate = ['pomegranate>=1.1.2,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ + 'xgboost>=2.1.3', + 'pomegranate>=1.1.2,<2.0', 'sdmetrics[torch]', 'pytest>=6.2.5,<7', 'pytest-cov>=2.6.0,<3', @@ -74,7 +76,7 @@ test = [ 'pytest-runner>=2.11.1', ] dev = [ - 'sdmetrics[test, xgboost, torch, pomegranate]', + 'sdmetrics[test, torch]', # general 'build>=1.0.0,<2', From 5568431e7ecc2b5a6f0a7e6f31496de2b93759d3 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 28 Feb 2025 15:58:45 -0400 Subject: [PATCH 11/12] update pyproject --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6da43f9..ed47ede8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,9 +63,7 @@ torch = [ pomegranate = ['pomegranate>=1.1.2,<2.0'] xgboost = ['xgboost>=2.1.3'] test = [ - 'xgboost>=2.1.3', - 'pomegranate>=1.1.2,<2.0', - 'sdmetrics[torch]', + 'sdmetrics[pomegranate,torch,xgboost]', 'pytest>=6.2.5,<7', 'pytest-cov>=2.6.0,<3', 'pytest-rerunfailures>=10.3,<15', From 568b7227d2121df64da7d275d5b6b4facb197fa4 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Mon, 10 Mar 2025 10:44:00 -0400 Subject: [PATCH 12/12] use replace() --- sdmetrics/single_table/bayesian_network.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sdmetrics/single_table/bayesian_network.py b/sdmetrics/single_table/bayesian_network.py index 0727de82..b5d69f52 100644 --- a/sdmetrics/single_table/bayesian_network.py +++ b/sdmetrics/single_table/bayesian_network.py @@ -44,11 +44,8 @@ def _likelihoods(cls, real_data, synthetic_data, metadata=None, structure=None): } for column in fields } - for column in fields: - real_data[column] = real_data[column].map(category_to_integer[column]).astype('int64') - synthetic_data[column] = ( - synthetic_data[column].map(category_to_integer[column]).astype(int) - ) + real_data[fields] = real_data[fields].replace(category_to_integer).astype('int64') + synthetic_data[fields] = synthetic_data[fields].replace(category_to_integer).astype('int64') bn.fit(torch.tensor(real_data[fields].to_numpy())) LOGGER.debug('Evaluating likelihood of the synthetic data')