Skip to content

Commit

Permalink
ENH: Allow custom bandwidth functions in KDEUnivariate fit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbivolaru committed Aug 26, 2020
1 parent 1e6b7c6 commit 6ce3e5b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
10 changes: 8 additions & 2 deletions statsmodels/nonparametric/kde.py
Expand Up @@ -344,7 +344,10 @@ def kdensity(X, kernel="gau", bw="normal_reference", weights=None, gridsize=None

# if bw is None, select optimal bandwidth for kernel
try:
bw = float(bw)
if callable(bw):
bw = float(bw(X, kern))
else:
bw = float(bw)
except:
bw = bandwidths.select_bandwidth(X, bw, kern)
bw *= adjust
Expand Down Expand Up @@ -454,7 +457,10 @@ def kdensityfft(X, kernel="gau", bw="normal_reference", weights=None, gridsize=N
kern = kernel_switch[kernel]()

try:
bw = float(bw)
if callable(bw):
bw = bw(X, kern) # user passed a callable custom bandwidth function
else:
bw = float(bw)
except:
bw = bandwidths.select_bandwidth(X, bw, kern) # will cross-val fit this pattern?
bw *= adjust
Expand Down
16 changes: 16 additions & 0 deletions statsmodels/nonparametric/tests/test_kde.py
Expand Up @@ -348,3 +348,19 @@ def test_fit_self(reset_randomstate):
kde = KDE(x)
assert isinstance(kde, KDE)
assert isinstance(kde.fit(), KDE)


class TestKDECustomBandwidth(object):

@classmethod
def setup_class(cls):
cls.kde = KDE(Xi)
cls.weights_200 = np.linspace(1, 100, 200)
cls.weights_100 = np.linspace(1, 100, 100)

def test_check_is_fit_ok_with_custom_bandwidth(self):
def custom_bw(X, kern):
return np.std(X) * len(X)
kde = self.kde.fit(bw=custom_bw)
assert isinstance(kde, KDE)

0 comments on commit 6ce3e5b

Please sign in to comment.