ENH add Naive Bayes Metaestimator ColumnwiseNB (aka "GeneralNB") #22574

Open
wants to merge 218 commits into
base: main
Changes from 189 commits
75ce0de
Add abstract methods to _BaseDiscreteNB and minor corrections in comm…
Feb 21, 2022
6f80025
Merge remote-tracking branch 'upstream/main' into naive-bayes-abstract
Feb 22, 2022
62ebbe0
Implemented ColumnwiseNB and tests
Feb 22, 2022
4ca9ac5
Added my name to module authors.
Feb 22, 2022
d7a9bf4
ColumnwiseNB docstring correction, See Also, Example. Added example t…
Feb 22, 2022
2dfcd40
black formatting compliance.
Feb 22, 2022
861a573
black formatting compliance.
Feb 22, 2022
9b73275
ColumnwiseNB: added _required_parameters = [estimators]
Feb 22, 2022
fa66575
Dirty trick with ColumnwiseNB._required_parameters to pass tests
Feb 22, 2022
12677e7
ColumnwiseNB docstring: added extended summary
Feb 23, 2022
e748cf1
ColumnwiseNB test issue: added to VALIDATE_ESTIMATOR_INIT exclusion list
Feb 23, 2022
8b7cfc8
ColumnwiseNB: rename 'estimators' into 'estimatorNBs'
Feb 23, 2022
7cb8430
Rename 'estimators' into 'estimatorNBs' in test_naive_bayes.py
Feb 23, 2022
2abac32
ColumnwiseNB: rename 'estimators' in the example too
Feb 23, 2022
c9a5e1d
Added pytest skip when no pandas to test_naive_bayes.py
Feb 23, 2022
79f980f
ColumnwiseNB: update class prior AFTER update fitted estimators
Feb 23, 2022
c2cd353
test_naive_bayes.py Added more tests to improve coverage
Feb 23, 2022
7ac5986
Merge branch 'main' into ColumnwiseNB
Feb 27, 2022
a4939f2
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Feb 27, 2022
de1123b
Add DOC entry and corrections to DOCSTRING
Feb 27, 2022
e515b6a
ColumnwiseNB: DOCSTRING correction
Feb 27, 2022
b120499
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Feb 27, 2022
9dbbb4d
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Feb 28, 2022
8c86193
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Mar 5, 2022
cf414c5
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 7, 2022
0c64f2e
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 9, 2022
4ccf640
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 10, 2022
94041f5
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 12, 2022
99d3793
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 14, 2022
6cc8ce4
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 14, 2022
3e7f53a
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 18, 2022
c70c675
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 20, 2022
8c90d29
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 21, 2022
a74fd15
Merge branch 'scikit-learn:main' into ColumnwiseNB asdf
avm19 Mar 22, 2022
2cc6d5a
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 23, 2022
c8f09cd
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 24, 2022
6033041
Merge remote-tracking branch 'origin/main' into ColumnwiseNB
Mar 31, 2022
eba339f
Replace ColumnwiseNB exception from init test in VALIDATE_ESTIMATOR_INIT
Apr 1, 2022
0dbb07f
Reformatting to comply with black=22.3.0
Apr 1, 2022
5add60b
Fixing the init and set_params test. Cf. #22537
Apr 1, 2022
f0fd8b3
ColumnwiseNB: rename estimatorsNBs to nb_estimators
Apr 1, 2022
1c884f3
tests for ColumnwiseNB: rename estimatorsNBs to nb_estimators
Apr 1, 2022
0dab95b
flake8 fix in text_naive_bayes.py
Apr 1, 2022
1bd1059
black fix in test_naive_bayes.py
Apr 1, 2022
bccf49b
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 2, 2022
081114f
Merge branch 'main' of github.com:scikit-learn/scikit-learn into Colu…
Apr 7, 2022
5979e7b
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 8, 2022
fcf032c
Added example: ColumnwiseNB for titanic dataset
Apr 9, 2022
d3d606f
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
Apr 9, 2022
1966cca
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 9, 2022
e2e3ccf
flake8 fix
Apr 9, 2022
75ce7af
ColumnwiseNB: added _check_n_features to fit and partial_fit
Apr 9, 2022
2d706bd
Correction to the example (ColumnwiseNB for titanic dataset)
Apr 10, 2022
4203f7b
Added a section to the naive bayes guide in documentation
Apr 10, 2022
ffc4b60
CI fix: try n_retires=10 in fetch_openml
Apr 10, 2022
b5628ee
Merge branch 'main' of github.com:scikit-learn/scikit-learn into Colu…
Apr 10, 2022
31b47a0
Fix formatting in the gallery example
Apr 10, 2022
64adf79
Fix formatting in the gallery example
Apr 10, 2022
f6eaab7
Improve documentation and the gallery example
Apr 10, 2022
a40fdca
Re #21355 'no validation at init'-test: Remove the logic at setter, k…
Apr 11, 2022
5492c60
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Apr 11, 2022
40085e5
Add test for error when subestimator does not compute class priors, b…
Apr 11, 2022
019c0e0
Extend the test for class priors extraction to MultinomialNB to cover…
Apr 11, 2022
197740d
Add ColumnwiseNB._sk_visual_block_ method for better HTML representation
Apr 11, 2022
b275993
Add test for ColumnwiseNB._sk_visual_block()
Apr 11, 2022
5a8d59e
Black formatting correction
Apr 11, 2022
be83512
Tests for ColumnwiseNB priors: Remove unnecessary definitions
Apr 11, 2022
50ab0b7
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 12, 2022
2e09146
Merge branch 'main' of github.com:scikit-learn/scikit-learn into Colu…
Apr 12, 2022
b93f39d
Change log: add an entry for ColumnwiseNB
Apr 12, 2022
3a82e8d
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 14, 2022
7aa60ec
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 16, 2022
4fab0e9
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 18, 2022
311d134
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 20, 2022
7eb9e2f
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 21, 2022
8eabb30
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 22, 2022
b91ed02
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 22, 2022
595bb58
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 23, 2022
ff4be14
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 28, 2022
616897e
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 1, 2022
e423db9
Changelog entry moved from v1.1.rst to v1.2.rst
May 1, 2022
199cbeb
Merge upstream
May 6, 2022
e86757d
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 10, 2022
de7db85
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
May 18, 2022
6fdd100
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
May 18, 2022
f3ecdfb
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 21, 2022
fb4c9c8
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 29, 2022
8ee2cd2
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jun 2, 2022
1047ad5
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jun 4, 2022
261bc5d
Format cited code in docstring sklearn/naive_bayes.py
avm19 Jun 4, 2022
c7334a2
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
Jun 4, 2022
c19e76b
Remove unnecessary import in test.naive_bayes.py::test_cwnb_example
Jun 4, 2022
1eb5d49
Update authors in examples/miscellaneous/plot_combining_naive_bayes.py
avm19 Jun 4, 2022
7a8ee18
Split test functions and give better names in test_naive_bayes.py
Jun 4, 2022
03e48e6
Namechange and minor comments in test_naive_bayes.py
Jun 5, 2022
e1449a0
Test union GaussianNBs matches single one when priors are specified
Jun 5, 2022
e60cd80
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jun 18, 2022
0b1829c
Implement _BaseNB.predict_joint_log_proba method and test for it
Jun 18, 2022
ec9033e
_BaseNB.predict_join_log_proba improve docstring
Jun 18, 2022
9c721ff
Merge branch 'naive-bayes-jll' into ColumnwiseNB
Jun 18, 2022
c321f63
Changelog entry
Jun 18, 2022
c65bbc7
Merge branch 'naive-bayes-jll' into ColumnwiseNB
Jun 18, 2022
2be313e
Use predict_joint_log_proba instead of _joint_log_likelihood in sub-e…
Jun 18, 2022
0ba8cb5
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jun 26, 2022
e4df039
Merge branch 'main' of github.com:scikit-learn/scikit-learn into Colu…
Jul 7, 2022
fe6fab8
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
Jul 7, 2022
aae81e9
Common parameter validation towards #23462 and custom test
Jul 7, 2022
4ba7bac
Docs terminology log-likelihood -> log-probability
Jul 7, 2022
f1b2d02
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jul 25, 2022
389bad1
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Aug 12, 2022
1224e9a
Empty commit to trigger pipeline
Aug 12, 2022
433e071
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Aug 21, 2022
67418a1
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Sep 2, 2022
f3015a4
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Sep 6, 2022
8850648
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Sep 23, 2022
2004fa8
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Sep 30, 2022
9c840f8
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Oct 7, 2022
0e5cfd2
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Oct 13, 2022
2779e07
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Oct 13, 2022
397cc1b
Parameter parser='auto' in fetch_openml. See #21938
Oct 18, 2022
33622ae
Use set_config(transform_output=pandas) and string feature names. See…
Oct 18, 2022
ef69554
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Oct 27, 2022
dbfac78
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Nov 6, 2022
a8f9413
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Nov 11, 2022
92beb26
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Nov 29, 2022
e7ef3e9
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Dec 8, 2022
3e63954
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Dec 18, 2022
9d8161d
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Dec 27, 2022
c36f36a
Merge branch 'main' into ColumnwiseNB
glemaitre Dec 28, 2022
a5ea324
DOC update changelog
glemaitre Dec 28, 2022
446cd5e
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Dec 29, 2022
3743b17
Minor format suggestions from glemaitre's review
avm19 Dec 29, 2022
df70750
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Dec 29, 2022
75a00f4
Docstring: versionadded note and a reference to the User Guide
Dec 29, 2022
e70e6c8
Docstring for n_features_in_
Dec 29, 2022
df49fb4
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Dec 30, 2022
3291276
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 1, 2023
f19914d
Remove n_classes_ attribute
Jan 1, 2023
23469d7
named_estimators_ is now not a property, but a field
Jan 1, 2023
9683a2b
Minor formatting: f-string in place of old style
Jan 1, 2023
03ef84d
Factor out _fit_partial from fit and fit_partial
Jan 2, 2023
2b8bde7
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 2, 2023
1198f25
Ensure ColumnwiseNB.class_count_ is float64
Jan 2, 2023
983a3a6
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 3, 2023
21ae456
Docstring: clarify callable columns are evaluated only once
Jan 3, 2023
51b3a39
Decorate partial_fit with available_if
Jan 4, 2023
0eb36a9
Improve _validate_estimators for when non-tuples are passed
Jan 4, 2023
8d27d6e
Improve _validate_estimators and test
Jan 4, 2023
76b32ad
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 4, 2023
298f3ae
black formatting
Jan 4, 2023
09eaf7a
Use .utils._encode._unique instead of np.unique
Jan 4, 2023
b38206b
Docstring: replace double backticks with single backticks
Jan 4, 2023
844afec
Remove and/or correct comments
Jan 5, 2023
77313a3
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 5, 2023
ac67bee
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 6, 2023
75b2cfc
Correct mistake that fit() does not fit from scratch. Test
Jan 6, 2023
a3a0653
Common tests and estimator checks for ColumnwiseNB (more)
Jan 6, 2023
5201045
Pass tests by removing memory address in column selector __repr__
Jan 7, 2023
043dbd6
Empty commit to trigger build pipeline
Jan 8, 2023
d077501
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 14, 2023
23ec142
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jan 20, 2023
5434ef3
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jan 26, 2023
07d4b9d
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jan 27, 2023
755b459
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
Jan 27, 2023
6c5c7d8
Use utils.parallel.delayed not utils.fixes.delayed
Jan 27, 2023
7d6c837
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Feb 6, 2023
d849fdf
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Feb 15, 2023
6a69cc2
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Feb 19, 2023
58613ed
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Feb 27, 2023
0f36fe6
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 5, 2023
8feddf6
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 11, 2023
be5cba5
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 20, 2023
937c857
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 25, 2023
2e1b3e4
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Mar 26, 2023
a8a07c4
TST use global_random_seed towards #22827
Mar 27, 2023
70f38fb
Trigger build
Mar 27, 2023
a9051bb
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Mar 31, 2023
b63fd28
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 8, 2023
0159a80
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 15, 2023
d34e387
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 24, 2023
764013c
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Apr 25, 2023
65b1e50
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 16, 2023
9d6eb91
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 May 22, 2023
4df6ff5
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jun 11, 2023
f8277ba
Minor typo in a comment
avm19 Jun 15, 2023
4aa9b83
Add np.where to cover the possibility of zero prior
avm19 Jun 15, 2023
d14dae9
black formatting
Jun 15, 2023
d30ac13
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jun 15, 2023
8ea37ff
Reformatting to comply with black=23.3.0
Jun 15, 2023
cf69151
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jun 15, 2023
dbbeaf5
Apply _fit_context decorator. Cf. #26473
Jun 15, 2023
ff02ae8
black formatting
Jun 15, 2023
923ec6e
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jun 26, 2023
cb5a43d
Formatting ruff
Jun 26, 2023
b8cd150
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Jun 27, 2023
af6df1a
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Jul 2, 2023
7257e77
Merge remote-tracking branch 'origin/ColumnwiseNB' into ColumnwiseNB
Jul 3, 2023
4ec0924
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Aug 15, 2023
96fdb12
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Aug 16, 2023
37a3341
Merge branch 'scikit-learn:main' into ColumnwiseNB
avm19 Aug 20, 2023
f62c3c3
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Sep 11, 2023
e2d48d8
Move the changelog entry from 1.3 to 1.4
Sep 11, 2023
45e4381
Move ColumnwiseNB section before Out of Core section
Sep 12, 2023
95d00cd
Change versionadded from 1.3 to 1.4
Sep 12, 2023
a1ed149
Update sklearn/naive_bayes.py
avm19 Sep 12, 2023
6c09328
Update sklearn/naive_bayes.py
avm19 Sep 12, 2023
b0ccfb0
Update sklearn/naive_bayes.py
avm19 Sep 12, 2023
52c2265
Merge branch 'ColumnwiseNB' of github.com:avm19/scikit-learn into Col…
Sep 12, 2023
53fb1a9
Formatting
Sep 12, 2023
ddf5f55
Fix test re _parameter_constraints = {'nb_estimators': [list], ...}
Sep 12, 2023
78c70eb
Change nb_estimator to naive_bayes_estimator in all docstring
Sep 12, 2023
95c9c69
Rename 'nb_estimators' into 'estimators'
Sep 12, 2023
21e4fb8
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Sep 12, 2023
3c3315c
Fix rename 'nb_estimators' into 'estimators'
Sep 12, 2023
ca78f52
Fix _fit_context decorator.
Sep 12, 2023
585d778
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Sep 13, 2023
7d0ad34
Change _iter signature to mirror ColumnTransformer changes in #27005
Sep 13, 2023
919e2ca
Merge remote-tracking branch 'upstream/main' into ColumnwiseNB
Sep 13, 2023
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1316,6 +1316,7 @@ Visualization
naive_bayes.ComplementNB
naive_bayes.GaussianNB
naive_bayes.MultinomialNB
naive_bayes.ColumnwiseNB


.. _neighbors_ref:
35 changes: 35 additions & 0 deletions doc/modules/naive_bayes.rst
@@ -281,3 +281,38 @@ For an overview of available strategies in scikit-learn, see also the
The ``partial_fit`` method call of naive Bayes models introduces some
computational overhead. It is recommended to use data chunk sizes that are as
large as possible, that is as the available RAM allows.

.. _columnwise_naive_bayes:

Mix and match naive Bayes models
--------------------------------

A naive Bayes model that assumes different distribution families for different
features (or subsets of features) can be constructed using :class:`ColumnwiseNB`.
It is a meta-estimator, whose operation relies on naive Bayes
sub-estimators, which can be chosen in any number or combination from
:class:`GaussianNB`, :class:`MultinomialNB`, :class:`ComplementNB`,
:class:`BernoulliNB`, :class:`CategoricalNB`, and user-defined models
(provided they implement necessary methods).

When creating a :class:`ColumnwiseNB` estimator, one specifies sub-estimators
and their respective column subsets as a list of tuples.
Each sub-estimator is fitted and evaluated independently of the
others and "sees" only the features assigned to it. The predictions of sub-estimators are
combined via

.. math::

    \log P(x,y)=\log P(x_{1},y) + \dots + \log P(x_{M},y) - (M - 1)\log P(y),

where :math:`\log P(x,y)` is the joint log-probability predicted by the meta-estimator,
:math:`\log P(x_{m},y)` is that by the :math:`m` th sub-estimator,
:math:`\log P(y)` is the class prior used by the meta-estimator, and
:math:`M\geq1` is the total number of sub-estimators.
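
The combination rule above can be checked on toy numbers. Each sub-estimator's joint log-probability already contains the class prior once, since :math:`\log P(x_{m},y)=\log P(x_{m}|y)+\log P(y)`, so summing :math:`M` of them counts the prior :math:`M` times and the formula subtracts the surplus. The sketch below is a minimal, standard-library-only illustration; the helper name `combine_joint_log_proba` and all probabilities are made up for this example and are not part of the PR.

```python
import math


def combine_joint_log_proba(sub_joint_log_probas, log_prior):
    """Combine M per-sub-estimator joint log-probabilities log P(x_m, y).

    sub_joint_log_probas: list of M lists, one value per class.
    log_prior: list with log P(y), one value per class.
    (Hypothetical helper written for this illustration only.)
    """
    n_sub = len(sub_joint_log_probas)
    return [
        sum(jlp[k] for jlp in sub_joint_log_probas) - (n_sub - 1) * log_prior[k]
        for k in range(len(log_prior))
    ]


# Toy values (assumed): two classes, two sub-estimators.
prior = [0.6, 0.4]   # P(y)
lik_1 = [0.2, 0.5]   # P(x_1 | y), e.g. from a Gaussian sub-model
lik_2 = [0.7, 0.1]   # P(x_2 | y), e.g. from a categorical sub-model

log_prior = [math.log(p) for p in prior]
jlp_1 = [math.log(p * l) for p, l in zip(prior, lik_1)]  # log P(x_1, y)
jlp_2 = [math.log(p * l) for p, l in zip(prior, lik_2)]  # log P(x_2, y)

combined = combine_joint_log_proba([jlp_1, jlp_2], log_prior)

# Direct naive Bayes computation: log( P(y) * P(x_1|y) * P(x_2|y) ).
expected = [math.log(p * a * b) for p, a, b in zip(prior, lik_1, lik_2)]

# Index of the class with the largest joint log-probability.
predicted_class = max(range(len(combined)), key=combined.__getitem__)
```

Here `combined` agrees with `expected` up to floating-point rounding, which is the point of the :math:`(M - 1)\log P(y)` correction.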

See :ref:`sphx_glr_auto_examples_miscellaneous_plot_combining_naive_bayes.py`
for an example of a mixed naive Bayes model implementation.

See also :ref:`voting_classifier` for a way of combining general classifiers.
An introduction to processing datasets with heterogeneous features is available at
:ref:`column_transformer`.
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -582,6 +582,10 @@ Changelog

:mod:`sklearn.naive_bayes`
..........................
- |Feature| Add :class:`naive_bayes.ColumnwiseNB`, a meta-estimator that allows
  existing naive Bayes classifiers to be combined and applied to different columns
  of `X`. :pr:`22574` by :user:`Andrey Melnik <avm19>`.


- |Fix| :class:`naive_bayes.GaussianNB` does not raise anymore a `ZeroDivisionError`
when the provided `sample_weight` reduces the problem to a single class in `fit`.
134 changes: 134 additions & 0 deletions examples/miscellaneous/plot_combining_naive_bayes.py
@@ -0,0 +1,134 @@
"""
===================================================
Combining Naive Bayes Estimators using ColumnwiseNB
===================================================

.. currentmodule:: sklearn

This example shows how to use :class:`~naive_bayes.ColumnwiseNB`
meta-estimator to construct a naive Bayes model from base naive Bayes
estimators. The resulting model is applied to a dataset with a mixture of
discrete and continuous features.

We consider the titanic dataset, in which:

- numerical (continuous) features "age" and "fare" are handled by
  :class:`~naive_bayes.GaussianNB`;
- categorical (discrete) features "embarked", "sex", and "pclass" are handled
  by :class:`~naive_bayes.CategoricalNB`.
"""

# Author: Andrey V. Melnik <andrey.melnik.maths@gmail.com>
# Pedro Morales <part.morales@gmail.com>
#
# License: BSD 3 clause

# %%
import pandas as pd
from sklearn import set_config
from sklearn.datasets import fetch_openml

set_config(transform_output="pandas")

X, y = fetch_openml(
    "titanic", version=1, as_frame=True, return_X_y=True, n_retries=10, parser="auto"
)
X["pclass"] = X["pclass"].astype("category")
# Add a category for NaNs to the "embarked" feature:
X["embarked"] = X["embarked"].cat.add_categories("N/A").fillna("N/A")

# %%
# Build and use a pipeline around ``ColumnwiseNB``
# ------------------------------------------------

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.naive_bayes import GaussianNB, CategoricalNB, ColumnwiseNB
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score

numeric_features = ["age", "fare"]
numeric_transformer = SimpleImputer(strategy="median")

categorical_features = ["embarked", "sex", "pclass"]
categories = [X[c].unique().to_list() for c in X[categorical_features]]
categorical_transformer = OrdinalEncoder(categories=categories)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        ("cat", categorical_transformer, categorical_features),
    ],
    verbose_feature_names_out=False,
)

classifier = ColumnwiseNB(
    nb_estimators=[
        ("gnb", GaussianNB(), numeric_features),
        ("cnb", CategoricalNB(), categorical_features),
    ]
)

pipe = Pipeline(steps=[("preprocessor", preprocessor), ("classifier", classifier)])
pipe
# %%
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)
print(f"Test accuracy: {accuracy_score(y_test, y_pred)}")

# %%
# Compare choices of columns using ``GridSearchCV``
# --------------------------------------------------
#
# The allocation of columns to constituent sub-estimators can be regarded as a
# hyperparameter. We can explore combinations of column assignments and values
# of other hyperparameters with the help of :class:`~.model_selection.GridSearchCV`.

param_grid = {
    "classifier__nb_estimators": [
        [
            ("gnb", GaussianNB(), ["age", "fare"]),
            ("cnb", CategoricalNB(), categorical_features),
        ],
        [("gnb", GaussianNB(), []), ("cnb", CategoricalNB(), ["pclass"])],
        [("gnb", GaussianNB(), ["embarked"]), ("cnb", CategoricalNB(), [])],
    ],
    "preprocessor__num__strategy": ["mean", "most_frequent"],
}

grid_search = GridSearchCV(pipe, param_grid, cv=10)
grid_search

# %%
# Calling `fit` triggers the cross-validated search for the best
# combination of hyperparameters:

grid_search.fit(X_train, y_train)

print("Best params:")
print(grid_search.best_params_)

# %%
# As it turns out, the best results are achieved by the naive Bayes model when "sex"
# is the only feature used:

cv_results = pd.DataFrame(grid_search.cv_results_)
cv_results = cv_results.sort_values("mean_test_score", ascending=False)
cv_results["Columns dictionary"] = cv_results["param_classifier__nb_estimators"].map(
    lambda l: {e[0]: e[-1] for e in l}
)
cv_results["'gnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["gnb"])
cv_results["'cnb' columns"] = cv_results["Columns dictionary"].map(lambda d: d["cnb"])
cv_results[
    [
        "mean_test_score",
        "std_test_score",
        "param_preprocessor__num__strategy",
        "'gnb' columns",
        "'cnb' columns",
    ]
]
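
As a rough illustration of the "Columns dictionary" step, the sketch below applies the same dictionary comprehension to a plain list of `(name, estimator, columns)` tuples. The strings standing in for the estimator objects are placeholders, an assumption made so the snippet runs without scikit-learn; the variable names are likewise hypothetical.

```python
# Stand-in for one candidate value of "param_classifier__nb_estimators".
# The strings "GaussianNB()" and "CategoricalNB()" are placeholders, not
# real estimator objects.
candidate = [
    ("gnb", "GaussianNB()", ["age", "fare"]),
    ("cnb", "CategoricalNB()", ["embarked", "sex", "pclass"]),
]

# Same comprehension as in the example: map each sub-estimator name
# (first tuple element) to its column list (last tuple element).
columns_by_name = {entry[0]: entry[-1] for entry in candidate}

print(columns_by_name["gnb"])
print(columns_by_name["cnb"])
```

Each row of `cv_results` gets one such dictionary, from which the "'gnb' columns" and "'cnb' columns" fields are then read by key.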