Skip to content

Commit

Permalink
ISJ bandwidth computation now respects the weights (#71)
Browse files Browse the repository at this point in the history
* removes redundant callable checks

* adds weights to bandwidth API and implements them for ISJ

* uses weights in bandwidth calculation

* adds pycharm files to gitignore

* remove unused code

* add basic tests and fix for weights <= 0

* adds test for bandwidth weight being the same as resampling

* force tests to be lowercase functions

* add parameters to docstring

Co-authored-by: Tommy <10076072+tommyod@users.noreply.github.com>
  • Loading branch information
lukedyer-peak and tommyod committed Nov 20, 2020
1 parent 698c8aa commit 06408d3
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,5 @@ ENV/
# mypy
.mypy_cache/
.pytest_cache/

.idea/
2 changes: 1 addition & 1 deletion KDEpy/BaseKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def evaluate(self, grid_points=None, bw_to_scalar=True):
else:
bw = self.bw
elif callable(self.bw):
bw = self.bw(self.data)
bw = self.bw(self.data, self.weights)
else:
bw = self.bw
self.bw = bw
Expand Down
6 changes: 2 additions & 4 deletions KDEpy/FFTKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,10 @@ def evaluate(self, grid_points=None):
if not grid_is_sorted(self.grid_points):
raise ValueError("The grid must be sorted.")

if callable(self.bw):
bw = self.bw(self.data)
elif isinstance(self.bw, numbers.Number) and self.bw > 0:
if isinstance(self.bw, numbers.Number) and self.bw > 0:
bw = self.bw
else:
raise ValueError("The bw must be a callable or a number.")
raise ValueError("The bw must be a number.")
self.bw = bw

# Step 0 - Make sure data points are inside of the grid
Expand Down
2 changes: 0 additions & 2 deletions KDEpy/NaiveKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ def evaluate(self, grid_points=None):
bw = self.bw
if isinstance(bw, numbers.Number):
bw = np.asfarray(np.ones(self.data.shape[0]) * bw)
elif callable(bw):
bw = np.asfarray(np.ones(self.data.shape[0]) * bw(self.data))

# TODO: Implementation w.r.t grid points for faster evaluation
# See the SciPy evaluation for how this can be done
Expand Down
2 changes: 0 additions & 2 deletions KDEpy/TreeKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ def evaluate(self, grid_points=None, eps=10e-4):
bw = self.bw
if isinstance(bw, numbers.Number):
bw = np.asfarray(np.ones(obs) * bw)
elif callable(bw):
bw = np.asfarray(np.ones(obs) * bw(self.data))
else:
bw = np.asarray_chkfinite(bw, dtype=np.float)

Expand Down
28 changes: 24 additions & 4 deletions KDEpy/bw_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _root(function, N, args):
return x


def improved_sheather_jones(data):
def improved_sheather_jones(data, weights=None):
"""
The Improved Sheater Jones (ISJ) algorithm from the paper by Botev et al.
This algorithm computes the optimal bandwidth for a gaussian kernel,
Expand All @@ -136,12 +136,26 @@ def improved_sheather_jones(data):
sheather+jones+why+use+dct&source=bl&ots=1ETdKd_6EF&sig=jZk4R515GB1xsn-
VZVnjr-JfjSI&hl=en&sa=X&ved=2ahUKEwi1_czNncTcAhVGhqYKHaPiBtcQ6AEwA3oEC
AcQAQ#v=onepage&q=sheather%20jones%20why%20use%20dct&f=false
Parameters
----------
data: array-like
The data points. Data must have shape (obs, 1).
weights: array-like, optional
One weight per data point. Must have shape (obs,). If None is
passed, uniform weights are used.
"""
obs, dims = data.shape
if not dims == 1:
raise ValueError("ISJ is only available for 1D data.")

n = 2 ** 10

# weights <= 0 still affect calculations unless we remove them
if weights is not None:
data = data[weights > 0]
weights = weights[weights > 0]

# Setting `percentile` higher decreases the chance of overflow
xmesh = autogrid(data, boundary_abs=6, num_points=n, boundary_rel=0.5)
data = data.ravel()
Expand All @@ -155,7 +169,7 @@ def improved_sheather_jones(data):

# Use linear binning to bin the data on an equidistant grid, this is a
# prerequisite for using the FFT (evenly spaced samples)
initial_data = linear_binning(data.reshape(-1, 1), xmesh)
initial_data = linear_binning(data.reshape(-1, 1), xmesh, weights)
assert np.allclose(initial_data.sum(), 1)

# Compute the type 2 Discrete Cosine Transform (DCT) of the data
Expand Down Expand Up @@ -187,7 +201,7 @@ def improved_sheather_jones(data):
return bandwidth


def scotts_rule(data):
def scotts_rule(data, weights=None):
"""
Scotts rule.
Expand All @@ -204,6 +218,9 @@ def scotts_rule(data):
if not len(data.shape) == 2:
raise ValueError("Data must be of shape (obs, dims).")

if weights is not None:
warnings.warn("Scott's rule currently ignores all weights")

obs, dims = data.shape
if not dims == 1:
raise ValueError("Scotts rule is only available for 1D data.")
Expand All @@ -215,7 +232,7 @@ def scotts_rule(data):
return sigma * np.power(obs, -1.0 / (dims + 4))


def silvermans_rule(data):
def silvermans_rule(data, weights=None):
"""
Returns optimal smoothing (standard deviation) if the data is close to
normal.
Expand All @@ -236,6 +253,9 @@ def silvermans_rule(data):
if not dims == 1:
raise ValueError("Silverman's rule is only available for 1D data.")

if weights is not None:
warnings.warn("Silverman's rule currently ignores all weights")

if obs == 1:
return 1
if obs < 1:
Expand Down
2 changes: 0 additions & 2 deletions KDEpy/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def test_api_types(kde, bw, kernel, type_func):
Test the API. Data and weights may be passed as tuples, arrays, lists, etc.
"""
# Test various input types
data = [1, 2, 3]
weights = [4, 5, 6]
data = np.random.randn(64)
weights = np.random.randn(64) + 10
model = kde(kernel=kernel, bw=bw)
Expand Down
50 changes: 50 additions & 0 deletions KDEpy/tests/test_bw_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Tests for the bandwidth selection.
"""

import pytest
import numpy as np

from KDEpy.bw_selection import _bw_methods, improved_sheather_jones


@pytest.fixture(scope="module")
def data() -> np.ndarray:
return np.random.randn(100, 1)


@pytest.mark.parametrize("method", _bw_methods.values())
def test_equal_weights_dont_changed_bw(data, method):
weights = np.ones_like(data).squeeze() * 2
bw_no_weights = method(data, weights=None)
bw_weighted = method(data, weights=weights)
np.testing.assert_almost_equal(bw_no_weights, bw_weighted)


def test_isj_bw_weights_single_zero_weighted_point(data):
data_with_outlier = np.concatenate((data.copy(), np.array([[1000]])))
weights = np.ones_like(data_with_outlier).squeeze()
weights[-1] = 0

np.testing.assert_array_almost_equal(
improved_sheather_jones(data),
improved_sheather_jones(data_with_outlier, weights),
)


# multiple runs to allow a good spread of catching errors
@pytest.mark.parametrize("execution_number", range(5))
def test_isj_bw_weights_same_as_resampling(data, execution_number):
sample_weights = np.random.randint(low=1, high=100, size=len(data))
data_resampled = np.repeat(data, repeats=sample_weights).reshape((-1, 1))
np.testing.assert_array_almost_equal(
improved_sheather_jones(data_resampled),
improved_sheather_jones(data, sample_weights),
)


if __name__ == "__main__":
# --durations=10 <- May be used to show potentially slow tests
pytest.main(args=[__file__, "--doctest-modules", "-v", "--durations=15"])

0 comments on commit 06408d3

Please sign in to comment.