From 6ce3e5b2ab73ead6371af108e4641b262e023858 Mon Sep 17 00:00:00 2001 From: Dorian Bivolaru Date: Wed, 26 Aug 2020 17:06:10 +0900 Subject: [PATCH] ENH: Allow custom bandwidth functions in KDEUnivariate fit --- statsmodels/nonparametric/kde.py | 10 ++++++++-- statsmodels/nonparametric/tests/test_kde.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/statsmodels/nonparametric/kde.py b/statsmodels/nonparametric/kde.py index 9cca87cd866..641d3e33596 100644 --- a/statsmodels/nonparametric/kde.py +++ b/statsmodels/nonparametric/kde.py @@ -344,7 +344,10 @@ def kdensity(X, kernel="gau", bw="normal_reference", weights=None, gridsize=None # if bw is None, select optimal bandwidth for kernel try: - bw = float(bw) + if callable(bw): + bw = float(bw(X, kern)) + else: + bw = float(bw) except: bw = bandwidths.select_bandwidth(X, bw, kern) bw *= adjust @@ -454,7 +457,10 @@ def kdensityfft(X, kernel="gau", bw="normal_reference", weights=None, gridsize=N kern = kernel_switch[kernel]() try: - bw = float(bw) + if callable(bw): + bw = bw(X, kern) # user passed a callable custom bandwidth function + else: + bw = float(bw) except: bw = bandwidths.select_bandwidth(X, bw, kern) # will cross-val fit this pattern? bw *= adjust diff --git a/statsmodels/nonparametric/tests/test_kde.py b/statsmodels/nonparametric/tests/test_kde.py index 2b83af65767..6999e20324c 100644 --- a/statsmodels/nonparametric/tests/test_kde.py +++ b/statsmodels/nonparametric/tests/test_kde.py @@ -348,3 +348,19 @@ def test_fit_self(reset_randomstate): kde = KDE(x) assert isinstance(kde, KDE) assert isinstance(kde.fit(), KDE) + + +class TestKDECustomBandwidth(object): + + @classmethod + def setup_class(cls): + cls.kde = KDE(Xi) + cls.weights_200 = np.linspace(1, 100, 200) + cls.weights_100 = np.linspace(1, 100, 100) + + def test_check_is_fit_ok_with_custom_bandwidth(self): + def custom_bw(X, kern): + return np.std(X) * len(X) + kde = self.kde.fit(bw=custom_bw) + assert isinstance(kde, KDE) +