-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from simonthor/relbw
- Loading branch information
Showing
4 changed files
with
198 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
"""Tests for relativistic Breit-Wigner PDF.""" | ||
import pytest | ||
import tensorflow as tf | ||
import zfit | ||
|
||
# Important, do the imports below | ||
from zfit.core.testing import tester | ||
|
||
import zfit_physics as zphys | ||
|
||
# specify globals here. Do NOT add any TensorFlow but just pure python | ||
m_true = 125.0 | ||
gamma_true = 0.05 | ||
|
||
|
||
def create_relbw(m, gamma, limits): | ||
obs = zfit.Space("obs1", limits) | ||
relbw = zphys.pdf.RelativisticBreitWigner(m=m, gamma=gamma, obs=obs) | ||
return relbw, obs | ||
|
||
|
||
def test_relbw_pdf(): | ||
# Test PDF here | ||
relbw, _ = create_relbw(m_true, gamma_true, limits=(0, 200)) | ||
assert zfit.run(relbw.pdf(125.0)) == pytest.approx(12.732396211295313, rel=1e-4) | ||
assert relbw.pdf(tf.range(0.0, 200, 10_000)) <= relbw.pdf(125.0) | ||
|
||
sample = relbw.sample(1000) | ||
tf.debugging.assert_all_finite(sample, "Some samples from the relbw PDF are NaN or infinite") | ||
assert sample.n_events == 1000 | ||
assert all(tf.logical_and(0 <= sample, sample <= 200)) | ||
|
||
|
||
def test_relbw_integral(): | ||
# Test CDF and integral here | ||
relbw, obs = create_relbw(m_true, gamma_true, limits=(0, 200)) | ||
|
||
full_interval_analytic = zfit.run(relbw.analytic_integrate(obs, norm_range=False)) | ||
full_interval_numeric = zfit.run(relbw.numeric_integrate(obs, norm_range=False)) | ||
assert full_interval_analytic == pytest.approx(1.0, 1e-4) | ||
assert full_interval_numeric == pytest.approx(1.0, 2e-3) | ||
|
||
analytic_integral = zfit.run(relbw.analytic_integrate(limits=(50, 100), norm_range=False)) | ||
numeric_integral = zfit.run(relbw.numeric_integrate(limits=(50, 100), norm_range=False)) | ||
assert analytic_integral == pytest.approx(numeric_integral, 0.01) | ||
|
||
|
||
# register the pdf here and provide sets of working parameter configurations | ||
def relbw_params_factory(): | ||
m = zfit.Parameter("m", m_true) | ||
gamma = zfit.Parameter("gamma", gamma_true) | ||
return {"m": m, "gamma": gamma} | ||
|
||
|
||
tester.register_pdf(pdf_class=zphys.pdf.RelativisticBreitWigner, params_factories=relbw_params_factory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import zfit | ||
from zfit.util import ztyping | ||
from zfit import z | ||
from zfit.core.space import ANY_LOWER, ANY_UPPER, Space | ||
|
||
|
||
@z.function(wraps='tensor') | ||
def relbw_pdf_func(x, m, gamma): | ||
""" | ||
Calculate the relativistic Breit-Wigner PDF. | ||
Args: | ||
x: value(s) for which the CDF will be calculated. | ||
m: Mean value | ||
gamma: width | ||
Returns: | ||
`tf.Tensor`: The calculated PDF values. | ||
Notes: | ||
Based on code from this [github gist](https://gist.github.com/andrewfowlie/cd0ed7e6c96f7c9e88f85eb3b9665b97#file-bw-py-L87-L110) | ||
""" | ||
x = z.unstack_x(x) | ||
alpha = gamma / m | ||
gamma2 = m ** 2 * (1.0 + alpha ** 2) ** 0.5 | ||
k = ( | ||
2.0 ** (3.0 / 2.0) | ||
* m ** 2 | ||
* alpha | ||
* gamma2 | ||
/ (np.pi * (m ** 2 + gamma2) ** 0.5) | ||
) | ||
|
||
return k / ((x ** 2 - m ** 2) ** 2 + m ** 4 * alpha ** 2) | ||
|
||
|
||
class RelativisticBreitWigner(zfit.pdf.ZPDF): | ||
"""Relativistic Breit-Wigner distribution. | ||
Formula for PDF and CDF are based on https://gist.github.com/andrewfowlie/cd0ed7e6c96f7c9e88f85eb3b9665b97 | ||
Args: | ||
m: the average value | ||
gamma: the width of the distribution | ||
""" | ||
|
||
_N_OBS = 1 | ||
_PARAMS = ["m", "gamma"] | ||
|
||
def _unnormalized_pdf(self, x: tf.Tensor) -> tf.Tensor: | ||
"""Calculate the PDF at value(s) x. | ||
Args: | ||
x : Either one value or an array | ||
Returns: | ||
`tf.Tensor`: The value(s) of the unnormalized PDF at x. | ||
""" | ||
return relbw_pdf_func(x, m=self.params['m'], gamma=self.params['gamma']) | ||
|
||
|
||
@z.function(wraps='tensor') | ||
def arctan_complex(x): | ||
r"""Function that evaluates arctan(x) using tensorflow but also supports complex numbers. | ||
It is defined as | ||
.. math:: | ||
\mathrm{arctan}(x) = \frac{i}{2} \left(\ln(1-ix) - \ln(1+ix)\right) | ||
Args: | ||
x: tf.Tensor | ||
Returns: | ||
.. math:: \mathrm{arctan}(x) | ||
Notes: | ||
Formula is taken from https://www.wolframalpha.com/input/?i=arctan%28a%2Bb*i%29 | ||
TODO: move somewhere? | ||
""" | ||
return 1 / 2 * 1j * (tf.math.log(1 - 1j * x) - tf.math.log(1 + 1j * x)) | ||
|
||
|
||
@z.function(wraps='tensor') | ||
def relbw_cdf_func(x, m, gamma): | ||
""" | ||
Analytical function for the CDF of the relativistic Breit-Wigner distribution. | ||
Args: | ||
x: value(s) for which the CDF will be calculated. | ||
m: Mean value | ||
gamma: width | ||
Returns: | ||
`tf.Tensor`: The calculated CDF values. | ||
Notes: | ||
Based on code from this [github gist](https://gist.github.com/andrewfowlie/cd0ed7e6c96f7c9e88f85eb3b9665b97#file-bw-py-L112-L154) | ||
""" | ||
gamma = z.to_complex(gamma) | ||
m = z.to_complex(m) | ||
x = z.to_complex(z.unstack_x(x)) | ||
|
||
alpha = gamma / m | ||
gamma2 = m ** 2 * (1.0 + alpha ** 2) ** 0.5 | ||
k = 2.0 ** (3.0 / 2.0) * m ** 2 * alpha * gamma2 / (np.pi * (m ** 2 + gamma2) ** 0.5) | ||
|
||
arg_1 = z.to_complex(-1) ** (1.0 / 4.0) / (-1j + alpha) ** 0.5 * x / m | ||
arg_2 = z.to_complex(-1) ** (3.0 / 4.0) / (1j + alpha) ** 0.5 * x / m | ||
|
||
shape = -1j * arctan_complex(arg_1) / (-1j + alpha) ** 0.5 - arctan_complex(arg_2) / (1j + alpha) ** 0.5 | ||
norm = z.to_complex(-1) ** (1.0 / 4.0) * k / (2.0 * alpha * m ** 3) | ||
|
||
cdf_ = shape * norm | ||
cdf_ = z.to_real(cdf_) | ||
return cdf_ | ||
|
||
|
||
def relbw_integral(limits: ztyping.SpaceType, params: dict, model) -> tf.Tensor: | ||
""" | ||
Calculates the analytic integral of the relativistic Breit-Wigner PDF. | ||
Args: | ||
limits: An object with attribute rect_limits. | ||
params: A hashmap from which the parameters that defines the PDF will be extracted. | ||
model: Will be ignored. | ||
Returns: | ||
The calculated integral. | ||
""" | ||
lower, upper = limits.rect_limits | ||
lower_cdf = relbw_cdf_func(x=lower, m=params["m"], gamma=params["gamma"]) | ||
upper_cdf = relbw_cdf_func(x=upper, m=params["m"], gamma=params["gamma"]) | ||
return upper_cdf - lower_cdf | ||
|
||
|
||
# These lines of code adds the analytic integral function to RelativisticBreitWigner PDF. | ||
relbw_integral_limits = Space(axes=(0,), limits=(((ANY_LOWER,),), ((ANY_UPPER,),))) | ||
RelativisticBreitWigner.register_analytic_integral(func=relbw_integral, limits=relbw_integral_limits) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .models.pdf_argus import Argus | ||
from .models.pdf_relbw import RelativisticBreitWigner | ||
|
||
__all__ = ["Argus"] | ||
__all__ = ["Argus", "RelativisticBreitWigner"] |