Skip to content

Commit

Permalink
Merge pull request #230 from zfit/ext_methods
Browse files Browse the repository at this point in the history
Add ext_method
  • Loading branch information
jonas-eschle committed Apr 21, 2020
2 parents 6bc7429 + d634892 commit 10c29b3
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 5 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ Breaking changes
To extract limits from multiple limits, `MultiSpace` and `Space` are both iterables, returning
the containing spaces respectively itself (for the `Space` case).
- SumPDF changed in the behavior. Read above in the Major Features and Improvement.
- Integrals of extended PDFs are not extended anymore.
- Integrals of extended PDFs are not extended anymore, but `ext_integrate` now returns the
integral multiplied by the yield.

Depreceations
-------------
Expand Down
11 changes: 7 additions & 4 deletions tests/test_basePDF.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,13 @@ def test_normalization(pdf_factory):
assert probs == pytest.approx(1., rel=0.05)
assert log_probs == pytest.approx(tf.math.log(probs_small).numpy(), rel=0.05)
dist = dist.create_extended(z.constant(test_yield))
probs_extended = dist.pdf(samples)
result_extended = probs_extended.numpy()
result_extended = np.average(result_extended) * (high - low)
assert result_extended == pytest.approx(1, rel=0.05)
probs = dist.pdf(samples)
probs_extended = dist.ext_pdf(samples)
result = probs.numpy()
result = np.average(result) * (high - low)
result_ext = np.average(probs_extended) * (high - low)
assert result == pytest.approx(1, rel=0.05)
assert result_ext == pytest.approx(test_yield, rel=0.05)


@pytest.mark.parametrize('gauss_factory', [create_gauss1, create_test_gauss1])
Expand Down
2 changes: 2 additions & 0 deletions tests/test_pdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ def test_extended_gauss():
gauss_dists = [gauss1, gauss2, gauss3]

sum_gauss = SumPDF(pdfs=gauss_dists)
integral_true = sum_gauss.integrate((-1, 5)) * sum_gauss.get_yield()

assert zfit.run(integral_true) == pytest.approx(zfit.run(sum_gauss.ext_integrate((-1, 5))))
normalization_testing(pdf=sum_gauss, limits=obs1)


Expand Down
48 changes: 48 additions & 0 deletions zfit/core/basepdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,36 @@ def _call_unnormalized_pdf(self, x):
# "it received on initialization."
# "Original Error: {}".format(error))

@z.function(wraps='model')
def ext_pdf(self, x: ztyping.XTypeInput, norm_range: ztyping.LimitsTypeInput = None) -> ztyping.XType:
"""Probability density function scaled by yield, normalized over `norm_range`.
Args:
x (numerical): `float` or `double` `Tensor`.
norm_range (tuple, :py:class:`~zfit.Space`): :py:class:`~zfit.Space` to normalize over
Returns:
:py:class:`tf.Tensor` of type `self.dtype`.
"""
if not self.is_extended:
raise NotExtendedPDFError(f"{self} is not extended, cannot call `ext_pdf`")
return self.pdf(x=x, norm_range=norm_range) * self.get_yield()

@z.function(wraps='model')
def ext_log_pdf(self, x: ztyping.XTypeInput, norm_range: ztyping.LimitsTypeInput = None) -> ztyping.XType:
"""Log of probability density function scaled by yield, normalized over `norm_range`.
Args:
x (numerical): `float` or `double` `Tensor`.
norm_range (tuple, :py:class:`~zfit.Space`): :py:class:`~zfit.Space` to normalize over
Returns:
:py:class:`tf.Tensor` of type `self.dtype`.
"""
if not self.is_extended:
raise NotExtendedPDFError(f"{self} is not extended, cannot call `ext_pdf`")
return self.log_pdf(x=x, norm_range=norm_range) + tf.math.log(self.get_yield())

@_BasePDF_register_check_support(False)
def _pdf(self, x, norm_range):
raise SpecificFunctionNotImplementedError
Expand Down Expand Up @@ -370,6 +400,24 @@ def _fallback_log_pdf(self, x, norm_range):
def gradients(self, x: ztyping.XType, norm_range: ztyping.LimitsType, params: ztyping.ParamsTypeOpt = None):
raise BreakingAPIChangeError("Removed with 0.5.x: is this needed?")

@z.function(wraps='model')
def ext_integrate(self, limits: ztyping.LimitsType, norm_range: ztyping.LimitsType = None) -> ztyping.XType:
"""Integrate the function over `limits` (normalized over `norm_range` if not False).
Args:
limits (tuple, :py:class:`~zfit.ZfitSpace`): the limits to integrate over
norm_range (tuple, :py:class:`~zfit.ZfitSpace`): the limits to normalize over or False to integrate the
unnormalized probability
Returns:
:py:class`tf.Tensor`: the integral value as a scalar with shape ()
"""
norm_range = self._check_input_norm_range(norm_range)
limits = self._check_input_limits(limits=limits)
if not self.is_extended:
raise NotExtendedPDFError(f"{self} is not extended, cannot call `ext_pdf`")
return self.integrate(limits=limits, norm_range=norm_range) * self.get_yield()

def _apply_yield(self, value: float, norm_range: ztyping.LimitsType, log: bool) -> Union[float, tf.Tensor]:
if self.is_extended and not norm_range.limits_are_false:
if log:
Expand Down

0 comments on commit 10c29b3

Please sign in to comment.