Skip to content

Commit

Permalink
Add Truncated kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed Jun 28, 2020
1 parent 05821df commit f50d520
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
24 changes: 17 additions & 7 deletions tests/test_pdf_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,26 @@ def test_simple_kde():

size = 5000
data = np.random.normal(size=size, loc=2, scale=3)

limits = (-15, 5)
obs = zfit.Space("obs1", limits=limits)
data_truncated = obs.filter(data)
# data = np.concatenate([data, np.random.uniform(size=size * 1, low=-5, high=2.3)])
# data = tf.random.poisson(shape=(13000,), lam=7, dtype=ztypes.float)
limits = (-15, 5)
kde = zfit.models.kde.GaussianKDE1DimExactV1(data=data, bandwidth=h, obs=zfit.Space("obs1", limits=limits))
kde_adaptive = zfit.models.kde.GaussianKDE1DimExactV1(data=data, bandwidth='adaptiveV1',
obs=zfit.Space("obs1", limits=limits))
kde_silverman = zfit.models.kde.GaussianKDE1DimExactV1(data=data, bandwidth='silverman',
obs=zfit.Space("obs1", limits=limits))
kde = zfit.models.kde.GaussianKDE1DimV1(data=data, bandwidth=h, obs=obs,
truncate=False)
kde_adaptive = zfit.models.kde.GaussianKDE1DimV1(data=data, bandwidth='adaptiveV1',
obs=obs,
truncate=False)
kde_silverman = zfit.models.kde.GaussianKDE1DimV1(data=data, bandwidth='silverman',
obs=obs,
truncate=False)
kde_adaptive_trunc = zfit.models.kde.GaussianKDE1DimV1(data=data_truncated, bandwidth='adaptiveV1',
obs=obs,
truncate=True)

integral = kde.integrate(limits=limits, norm_range=False)
integral_trunc = kde_adaptive_trunc.integrate(limits=limits, norm_range=False)
integral_adaptive = kde_adaptive.integrate(limits=limits, norm_range=False)
integral_silverman = kde_silverman.integrate(limits=limits, norm_range=False)

Expand All @@ -40,7 +50,7 @@ def test_simple_kde():
# plt.show()

rel_tol = 0.04
assert zfit.run(integral) == pytest.approx(expected_integral, rel=rel_tol)
assert zfit.run(integral_trunc) == pytest.approx(1., rel=rel_tol)
assert zfit.run(integral) == pytest.approx(expected_integral, rel=rel_tol)
assert zfit.run(integral_adaptive) == pytest.approx(expected_integral, rel=rel_tol)
assert zfit.run(integral_silverman) == pytest.approx(expected_integral, rel=rel_tol)
35 changes: 26 additions & 9 deletions zfit/models/kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .dist_tfp import WrapDistribution
from .. import z, ztypes
from ..core.interfaces import ZfitData
from ..core.interfaces import ZfitData, ZfitSpace
from ..util import ztyping
from ..util.exception import OverdefinedError, ShapeIncompatibleError

Expand All @@ -34,9 +34,9 @@ def bandwidth_adaptiveV1(data, func):
return bandwidth


def _adaptive_bandwidth_KDEV1(constructor, obs, data, weights, name):
def _adaptive_bandwidth_KDEV1(constructor, obs, data, weights, name, truncate):
kde_silverman = constructor(obs=obs, data=data, bandwidth='silverman', weights=weights,
name=f"INTERNAL_{name}")
name=f"INTERNAL_{name}", truncate=truncate)
return bandwidth_adaptiveV1(data=data, func=kde_silverman.pdf)


Expand All @@ -53,7 +53,7 @@ def min_std_or_iqr(x):
return tf.minimum(tf.math.reduce_std(x), (tfp.stats.percentile(x, 75) - tfp.stats.percentile(x, 25)))


class GaussianKDE1DimExactV1(WrapDistribution):
class GaussianKDE1DimV1(WrapDistribution):
_N_OBS = 1

_bandwidth_methods = {
Expand All @@ -64,7 +64,8 @@ class GaussianKDE1DimExactV1(WrapDistribution):

def __init__(self, obs: ztyping.ObsTypeInput, data: ztyping.ParamTypeInput,
bandwidth: ztyping.ParamTypeInput = None,
weights: Union[None, np.ndarray, tf.Tensor] = None, name: str = "GaussianKDE1DimV1"):
weights: Union[None, np.ndarray, tf.Tensor] = None, truncate: bool = True,
name: str = "GaussianKDE1DimV1"):
r"""One dimensional Kernel Density Estimation with a Gaussian Kernel.
Kernel Density Estimation is a non-parametric method to approximate the density of given points.
Expand All @@ -84,10 +85,12 @@ def __init__(self, obs: ztyping.ObsTypeInput, data: ztyping.ParamTypeInput,
Args:
data: 1-D Tensor-like. The positions of the `kernel`, the :math:`x_i`. Determines how many kernels will be created.
bandwidth: Broadcastable to the batch and event shape of the distribution. A scalar will simply broadcast
bandwidth: Bandwidth of the kernel. Broadcastable to the batch and event shape of the distribution. A scalar will simply broadcast
to `data` for a 1-D distribution.
obs: Observables
weights: Weights of each `data`, can be None or Tensor-like with shape compatible with `data`
truncate: If a truncated Gaussian kernel should be used with the limits given by the `obs` lower and
upper limits. This can cause NaNs in case datapoints are outside of the limits.
name: Name of the PDF
"""
if bandwidth is None:
Expand Down Expand Up @@ -121,15 +124,28 @@ def __init__(self, obs: ztyping.ObsTypeInput, data: ztyping.ParamTypeInput,
raise ValueError(f"Cannot use {bandwidth} as a bandwidth method. Use a numerical value or one of"
f" the defined methods: {list(self._bandwidth_methods.keys())}")
bandwidth = bw_method(constructor=type(self), obs=obs, data=data, weights=weights,
name=f"INTERNAL_{name}")
name=f"INTERNAL_{name}", truncate=truncate)

bandwidth_param = -999 if bandwidth_param == 'adaptiveV1' else bandwidth # TODO: multiparam for bandwidth?

params = {'bandwidth': bandwidth_param}

# create distribution factory
def kernel_factory():
return tfp.distributions.Normal(loc=self._data, scale=self._bandwidth)
if truncate:
if not isinstance(obs, ZfitSpace):
raise ValueError(f"`obs` has to be a `ZfitSpace` if `truncated` is True.")
inside = obs.inside(data)
all_inside = tf.reduce_all(inside)
tf.debugging.assert_equal(all_inside, True, message="Not all data points are inside the limits but"
" a truncate kernel was chosen.")

def kernel_factory():
return tfp.distributions.TruncatedNormal(loc=self._data, scale=self._bandwidth,
low=self.space.rect_lower,
high=self.space.rect_upper)
else:
def kernel_factory():
return tfp.distributions.Normal(loc=self._data, scale=self._bandwidth)

dist_kwargs = lambda: dict(mixture_distribution=categorical,
components_distribution=kernel_factory())
Expand All @@ -145,3 +161,4 @@ def kernel_factory():
self._data_weights = weights
self._bandwidth = bandwidth
self._data = data
self._truncate = truncate

0 comments on commit f50d520

Please sign in to comment.