Skip to content

Commit

Permalink
add binned NLL
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-eschle committed Oct 23, 2020
1 parent 062757f commit 0dad0a2
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 5 deletions.
37 changes: 37 additions & 0 deletions tests/binnednll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2020 zfit

import boost_histogram as bh
import numpy as np

import zfit
from zfit.core.binneddata import BinnedData
from zfit.core.binning import RectBinning
from zfit.models.binned_functor import BinnedSumPDF
from zfit.models.template import BinnedTemplatePDF


def test_binned_nll_simple():
counts = np.random.uniform(high=1, size=(10, 20)) # generate counts
counts2 = np.random.normal(loc=5, size=(10, 20))
counts3 = np.linspace(0, 10, num=10)[:, None] * np.linspace(0, 5, num=20)[None, :]
binnings = [bh.axis.Regular(10, 0, 10), bh.axis.Regular(20, -10, 30)]
binning = RectBinning(binnings=binnings)
obs = zfit.Space(obs=['obs1', 'obs2'], binning=binning)

mc1 = BinnedData.from_numpy(obs=obs, counts=counts, w2error=10)
mc2 = BinnedData.from_numpy(obs=obs, counts=counts2, w2error=10)
mc3 = BinnedData.from_numpy(obs=obs, counts=counts3, w2error=10)

observed_data = BinnedData.from_numpy(obs=obs, counts=counts + counts2 + counts3, w2error=10)

pdf = BinnedTemplatePDF(data=mc1)
pdf2 = BinnedTemplatePDF(data=mc2)
pdf3 = BinnedTemplatePDF(data=mc3)
pdf.set_yield(np.sum(counts))
pdf2.set_yield(np.sum(counts2))
pdf3.set_yield(np.sum(counts3))
# assert len(pdf.ext_pdf(None)) > 0
pdf_sum = BinnedSumPDF(pdfs=[pdf, pdf2, pdf3], obs=obs)

nll = zfit.loss.ExtendedBinnedNLL(pdf_sum, data=observed_data)
nll.value()
1 change: 1 addition & 0 deletions zfit/_loss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright (c) 2020 zfit
31 changes: 31 additions & 0 deletions zfit/_loss/binnedloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) 2020 zfit
from typing import Iterable

from .. import z
from ..core.interfaces import ZfitBinnedPDF, ZfitBinnedData
from ..core.loss import BaseLoss
from ..util import ztyping

import tensorflow as tf


class ExtendedBinnedNLL(BaseLoss):

def __init__(self, model: ztyping.ModelsInputType, data: ztyping.DataInputType,
constraints: ztyping.ConstraintsTypeInput = None):
super().__init__(model=model, data=data, constraints=constraints, fit_range=None)

@z.function(wraps='loss')
def _loss_func(self, model: Iterable[ZfitBinnedPDF], data: Iterable[ZfitBinnedData],
fit_range, constraints):
poisson_terms = []
for mod, dat in zip(model, data):
poisson_terms.append(tf.nn.log_poisson_loss(dat.get_counts(obs=mod.obs),
tf.math.log(mod.ext_pdf(None)))) # TODO: change None
nll = tf.reduce_sum(poisson_terms, axis=0)

if constraints:
constraints = z.reduce_sum([c.value() for c in constraints])
nll += constraints

return nll
9 changes: 8 additions & 1 deletion zfit/core/binneddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .dimension import BaseDimensional
from .interfaces import ZfitBinnedData
from .. import z
from ..util.exception import WorkInProgressError
from ..util.ztyping import NumericalTypeReturn


Expand All @@ -28,10 +29,16 @@ def from_numpy(cls, obs, counts, w2error, name=None):
def _input_check_counts(self, counts): # TODO
return counts

def get_counts(self, bins=None) -> NumericalTypeReturn:
def get_counts(self, bins=None, obs=None) -> NumericalTypeReturn:
if bins is not None:
raise WorkInProgressError
if obs is not None and not obs == self.obs:
raise WorkInProgressError("Currently, reordering of axes not supported")
return self._counts

def weights_error_squared(self) -> NumericalTypeReturn:
return self._w2error

@property
def data_range(self):
return self.space
35 changes: 33 additions & 2 deletions zfit/core/binnedpdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from .baseobject import BaseNumeric
from .dimension import BaseDimensional
from .interfaces import ZfitBinnedPDF, ZfitSpace, ZfitParameter
from .. import convert_to_parameter
from .. import convert_to_parameter, convert_to_space
from ..util import ztyping
from ..util.cache import GraphCachable
from ..util.exception import SpecificFunctionNotImplementedError, WorkInProgressError, NotExtendedPDFError, \
AlreadyExtendedPDFError
AlreadyExtendedPDFError, SpaceIncompatibleError


class BaseBinnedPDF(BaseNumeric, GraphCachable, BaseDimensional, ZfitBinnedPDF):
Expand Down Expand Up @@ -163,3 +163,34 @@ def _fallback_sample(self, n, limits):

def _copy(self, deep, name, overwrite_params):
raise WorkInProgressError

# factor out with unbinned pdf
@property
def norm_range(self):
return self._norm_range

# TODO: factor out with unbinned pdf
def convert_sort_space(self, obs: Union[ztyping.ObsTypeInput, ztyping.LimitsTypeInput] = None,
axes: ztyping.AxesTypeInput = None,
limits: ztyping.LimitsTypeInput = None) -> Union[ZfitSpace, None]:
"""Convert the inputs (using eventually `obs`, `axes`) to :py:class:`~zfit.ZfitSpace` and sort them according to
own `obs`.
Args:
obs:
axes:
limits:
Returns:
"""
if obs is None: # for simple limits to convert them
obs = self.obs
elif not set(obs).intersection(self.obs):
raise SpaceIncompatibleError("The given space {obs} is not compatible with the obs of the pdfs{self.obs};"
" they are disjoint.")
space = convert_to_space(obs=obs, axes=axes, limits=limits)

if self.space is not None: # e.g. not the first call
space = space.with_coords(self.space, allow_superset=True, allow_subset=True)
return space
2 changes: 1 addition & 1 deletion zfit/core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def value(self):
class ZfitBinnedData(ZfitDimensional):

@abstractmethod
def get_counts(self, bins):
def get_counts(self, bins, obs):
raise NotImplementedError

@abstractmethod
Expand Down
8 changes: 7 additions & 1 deletion zfit/loss.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# Copyright (c) 2020 zfit
import warnings

from ._loss.binnedloss import ExtendedBinnedNLL
from .core.loss import ExtendedUnbinnedNLL, UnbinnedNLL, BaseLoss, SimpleLoss

__all__ = ['ExtendedUnbinnedNLL', "UnbinnedNLL", "BaseLoss", "SimpleLoss", 'experimental_enable_loss_penalty']
__all__ = ["ExtendedUnbinnedNLL",
"UnbinnedNLL",
"BaseLoss",
"SimpleLoss",
"experimental_enable_loss_penalty",
"ExtendedBinnedNLL"]

from .util.warnings import warn_experimental_feature

Expand Down

0 comments on commit 0dad0a2

Please sign in to comment.