diff --git a/binder/environment.yml b/binder/environment.yml new file mode 100644 index 0000000000..7ba02c3af7 --- /dev/null +++ b/binder/environment.yml @@ -0,0 +1,111 @@ +name: pyhf +channels: +- pytorch +- https://conda.anaconda.org/NLeSC +- defaults +dependencies: +- ca-certificates=2017.08.26=h1d4fec5_0 +- certifi=2018.1.18=py36_0 +- cffi=1.11.4=py36h9745a5d_0 +- cudatoolkit=8.0=3 +- cudnn=7.0.5=cuda8.0_0 +- freetype=2.8=hab7d2ae_1 +- intel-openmp=2018.0.0=hc7b2577_8 +- jpeg=9b=h024ee3a_2 +- libedit=3.1=heed3624_0 +- libffi=3.2.1=hd88cf55_4 +- libgcc-ng=7.2.0=h7cc24e2_2 +- libgfortran-ng=7.2.0=h9f7466a_2 +- libpng=1.6.34=hb9fc6fc_0 +- libstdcxx-ng=7.2.0=h7a57d05_2 +- libtiff=4.0.9=h28f6b97_0 +- mkl=2018.0.1=h19d6760_4 +- nccl=1.3.4=cuda8.0_1 +- ncurses=6.0=h9df7e31_2 +- numpy=1.14.0=py36h3dfced4_1 +- olefile=0.45.1=py36_0 +- openssl=1.0.2n=hb7f436b_0 +- pillow=5.0.0=py36h3deb7b8_0 +- pip=9.0.1=py36h6c6f9ce_4 +- pycparser=2.18=py36hf9f622e_1 +- python=3.6.4=hc3d631a_1 +- pytorch=0.3.0=py36cuda8.0cudnn7.0_0 +- readline=7.0=ha6073c6_4 +- setuptools=38.4.0=py36_0 +- six=1.11.0=py36h372c433_1 +- sqlite=3.22.0=h1bed415_0 +- tk=8.6.7=hc745277_3 +- torchvision=0.2.0=py36_0 +- wheel=0.30.0=py36hfd4bba0_1 +- xz=5.2.3=h55aa19d_2 +- zlib=1.2.11=ha838bed_2 +- cuda90=1.0=h6433d27_0 +- pip: + - absl-py>=0.1.10 + - ansiwrap>=0.8.3 + - attrs>=17.4.0 + - bleach>=1.5.0 + - boto3>=1.5.26 + - botocore>=1.8.40 + - chardet>=3.0.4 + - click>=6.7 + - coverage>=4.5.1 + - cycler>=0.10.0 + - decorator>=4.2.1 + - docutils>=0.14 + - entrypoints>=0.2.3 + - future>=0.16.0 + - graphviz>=0.8.1 + - html5lib>=0.9999999 + - idna>=2.6 + - ipython>=6.2.1 + - ipython-genutils>=0.2.0 + - jedi>=0.11.1 + - jinja2>=2.10 + - jmespath>=0.9.3 + - jsonschema>=2.6.0 + - jupyter-client>=5.2.2 + - jupyter-core>=4.4.0 + - markdown>=2.6.11 + - markupsafe>=1.0 + - matplotlib>=2.1.2 + - mistune>=0.8.3 + - mxnet>=1.0.0.post4 + - nbconvert>=5.3.1 + - nbformat>=4.4.0 + - pandas>=0.22.0 + - pandocfilters>=1.4.2 + - papermill>=0.12.2 + - parso>=0.1.1 + - pexpect>=4.4.0 + - pickleshare>=0.7.4 + - pluggy>=0.6.0 + - prompt-toolkit>=1.0.15 + - protobuf>=3.5.1 + - ptyprocess>=0.5.2 + - py>=1.5.2 + - pygments>=2.2.0 + - pyhf>=0.0.4 + - pyparsing>=2.2.0 + - pytest>=3.4.0 + - pytest-cov>=2.5.1 + - python-dateutil>=2.6.1 + - pytz>=2018.3 + - pyyaml>=3.12 + - pyzmq>=17.0.0 + - requests>=2.18.4 + - s3transfer>=0.1.12 + - scipy>=1.0.0 + - simplegeneric>=0.8.1 + - tensorflow>=1.5.0 + - tensorflow-tensorboard>=1.5.1 + - testpath>=0.3.1 + - textwrap3>=0.9.1 + - torch>=0.3.0 + - tornado>=4.5.3 + - tqdm>=4.19.5 + - traitlets>=4.3.2 + - uproot>=2.6.10 + - urllib3>=1.22 + - wcwidth>=0.1.7 + - werkzeug>=0.14.1 diff --git a/binder/requirements.txt b/binder/requirements.txt deleted file mode 100644 index 9079a3c250..0000000000 --- a/binder/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -matplotlib -numpy -papermill -pyhf -scipy -uproot diff --git a/examples/notebooks/example-mxnet.ipynb b/examples/notebooks/example-mxnet.ipynb new file mode 100644 index 0000000000..383245a6c3 --- /dev/null +++ b/examples/notebooks/example-mxnet.ipynb @@ -0,0 +1,96 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "%pylab inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pyhf\n", + "from pyhf 
import hfpdf\n", "from pyhf.simplemodels import hepdata_like\n", "import mxnet as mx" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------\n", "as MXNet\n", "--------\n", " \n", "[-22.87784958]\n", "\n" ] } ], "source": [ "source = {\n", " \"binning\": [2,-0.5,1.5],\n", " \"bindata\": {\n", " \"data\": [120.0, 180.0],\n", " \"bkg\": [100.0, 150.0],\n", " \"bkgerr\": [10.0, 10.0],\n", " \"sig\": [30.0, 95.0]\n", " }\n", "}\n", "\n", "pdf = hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])\n", "data = source['bindata']['data'] + pdf.config.auxdata\n", "\n", "init_pars = pdf.config.suggested_init()\n", "par_bounds = pdf.config.suggested_bounds()\n", "\n", "\n", "print('--------\\nas MXNet\\n--------')\n", "pyhf.tensorlib = pyhf.mxnet_backend()\n", "v = pdf.logpdf(init_pars, data)\n", "print(type(v),v)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 } diff --git a/pyhf/__init__.py b/pyhf/__init__.py index 7286b8ab68..4016a18747 100644 --- a/pyhf/__init__.py +++ b/pyhf/__init__.py @@ -18,6 +18,12 @@ except ImportError: pass +try: + from .tensor.mxnet_backend import mxnet_backend + assert mxnet_backend +except ImportError: + pass + tensorlib = numpy_backend() optimizer = scipy_optimizer() diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py new file mode 100644 index 0000000000..21aa72721b --- /dev/null +++ b/pyhf/tensor/mxnet_backend.py @@ -0,0 +1,320 @@ +from mxnet import nd +import logging +import math # Required for normal() +from numbers import Number # Required for normal() +log = logging.getLogger(__name__) + + +class mxnet_backend(object): + """MXNet backend for pyhf""" + + def __init__(self, **kwargs): + pass + + def tolist(self, tensor_in): + """ + Convert a tensor to a list. + + Args: + tensor_in (Tensor): Input MXNet tensor + + Returns: + list: The possibly nested list of tensor elements. + """ + tensor_in = self.astensor(tensor_in) + return tensor_in.asnumpy().tolist() + + def outer(self, tensor_in_1, tensor_in_2): + """ + The outer product of two tensors. + + Args: + tensor_in_1 (Tensor): Tensor object + tensor_in_2 (Tensor): Tensor object + + Returns: + MXNet NDArray: The outer product. + """ + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + + tensor_1_shape = tensor_in_1.shape + tensor_2_shape = tensor_in_2.shape + if len(tensor_1_shape) == 1: + tensor_1_shape = (tensor_1_shape[0], 1) + if len(tensor_2_shape) == 1: + tensor_2_shape = (tensor_2_shape[0], 1) + + rows1, cols1 = tensor_1_shape + rows2, cols2 = tensor_2_shape + return nd.reshape(nd.broadcast_mul(tensor_in_1.reshape((rows1, 1, cols1, 1)), + tensor_in_2.reshape((1, rows2, 1, cols2))), + (rows1 * cols1, rows2 * cols2)) + + def astensor(self, tensor_in): + """ + Convert a tensor to an MXNet NDArray. + + Args: + tensor_in (Number or Tensor): Tensor object + + Returns: + MXNet NDArray: A multi-dimensional, fixed-size homogeneous array.
+ """ + return nd.array(tensor_in) + + def sum(self, tensor_in, axis=None): + """ + Compute the sum of array elements over given axes. + + Args: + tensor_in (Tensor): Tensor object + axis (Number): The axes over which to sum + + Returns: + MXNet NDArray: ndarray of the sum over the axes. + """ + tensor_in = self.astensor(tensor_in) + if axis is None or tensor_in.shape == nd.array([]).size: + return nd.sum(tensor_in) + else: + return nd.sum(tensor_in, axis) + + def product(self, tensor_in, axis=None): + """ + Product of array elements over given axes. + + Args: + tensor_in (Tensor): Tensor object + axis (Number): The axes over which to take the product + + Returns: + MXNet NDArray: ndarray of the product over the axes. + """ + tensor_in = self.astensor(tensor_in) + if axis is None: + return nd.prod(tensor_in) + else: + return nd.prod(tensor_in, axis) + + def ones(self, shape): + """ + A new array filled with all ones, with the given shape. + + Args: + shape (Number): The shape of the array + + Returns: + MXNet NDArray: ndarray of 1's with given shape. + """ + return nd.ones(shape) + + def zeros(self, shape): + """ + A new array filled with all zeros, with the given shape. + + Args: + shape (Number): The shape of the array + + Returns: + MXNet NDArray: ndarray of 0's with given shape. + """ + return nd.zeros(shape) + + def power(self, tensor_in_1, tensor_in_2): + """ + Result of first array elements raised to powers from second array, + element-wise with broadcasting. + + Args: + tensor_in_1 (Tensor): Tensor object + tensor_in_2 (Tensor): Tensor object + + Returns: + MXNet NDArray: First array elements raised to powers from second array. + """ + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + return nd.power(tensor_in_1, tensor_in_2) + + def sqrt(self, tensor_in): + """ + Element-wise square-root value of the input. + + Args: + tensor_in (Tensor): Tensor object + + Returns: + MXNet NDArray: Element-wise square-root value. + """ + tensor_in = self.astensor(tensor_in) + return nd.sqrt(tensor_in) + + def divide(self, tensor_in_1, tensor_in_2): + """ + Element-wise division of the input arrays with broadcasting. + + Args: + tensor_in_1 (Tensor): Tensor object + tensor_in_2 (Tensor): Tensor object + + Returns: + MXNet NDArray: Element-wise division of the input arrays. + """ + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + return nd.divide(tensor_in_1, tensor_in_2) + + def log(self, tensor_in): + """ + Element-wise Natural logarithmic value of the input. + + Args: + tensor_in (Tensor): Tensor object + + Returns: + MXNet NDArray: Element-wise Natural logarithmic value. + """ + tensor_in = self.astensor(tensor_in) + return nd.log(tensor_in) + + def exp(self, tensor_in): + """ + Element-wise exponential value of the input. + + Args: + tensor_in (Tensor): Tensor object + + Returns: + MXNet NDArray: Element-wise exponential value. + """ + tensor_in = self.astensor(tensor_in) + return nd.exp(tensor_in) + + def stack(self, sequence, axis=0): + """ + Join a sequence of arrays along a new axis. + + The axis parameter specifies the index of the new axis in the dimensions + of the result. For example, if axis=0 it will be the first dimension and + if axis=-1 it will be the last dimension. + + Args: + sequence (Array of Tensors): Sequence of arrays + axis (Number): The axis along which to join the arrays + + Returns: + MXNet NDArray: ndarray comprised of the elements of the sequence. 
+ """ + return nd.stack(*sequence, axis=axis) + + def where(self, mask, tensor_in_1, tensor_in_2): + """ + Apply a boolean selection mask to the elements of the input tensors. + + Example:: + + >>> where( + astensor([1, 0, 1]), + astensor([1, 1, 1]), + astensor([2, 2, 2])) + [1. 2. 1.] + + Args: + mask (bool): Boolean mask (boolean or tensor object of booleans) + tensor_in_1 (Tensor): Tensor object + tensor_in_2 (Tensor): Tensor object + + Returns: + MXNet NDArray: The result of the mask being applied to the tensors. + """ + mask = self.astensor(mask) + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + return nd.add(nd.multiply(mask, tensor_in_1), + nd.multiply(nd.subtract(1, mask), tensor_in_2)) + + def concatenate(self, sequence): + """ + Join the elements of the sequence. + + Args: + sequence (Array of Tensors): The sequence of arrays to join + + Returns: + MXNet NDArray: The ndarray of the joined elements. + """ + return nd.concat(*sequence, dim=0) + + def simple_broadcast(self, *args): + """ + Broadcast a sequence of 1 dimensional arrays. + + Example:: + + >>> simple_broadcast( + astensor([1]), + astensor([2, 2]), + astensor([3, 3, 3])) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 3.]] + + Args: + args (Array of Tensors): Sequence of arrays + + Returns: + MXNet NDArray: The sequence broadcast together. + """ + max_dim = max(map(len, args)) + broadcast = [] + for arg in args: + if len(arg) < max_dim: + broadcast.append(nd.broadcast_axis( + arg[0], axis=len(arg.shape) - 1, size=max_dim)) + else: + broadcast.append(arg) + return nd.stack(*broadcast) + + def poisson(self, n, lam): + """ + The continous approximation to the probability density function of the Poisson + distribution given the parameters evaluated at `n`. + + Args: + n (Number or Tensor): The value at which to evaluate the Poisson distribution p.d.f. + (the observed number of events) + lam (Number or Tensor): The mean of the Poisson distribution p.d.f. + (the expected number of events) + + Returns: + MXNet NDArray: Value of N(n|lam, sqrt(lam)), the continous approximation to Poisson(n|lam). + """ + return self.normal(n, lam, self.sqrt(lam)) + + def normal(self, x, mu, sigma): + """ + The probability density function of the Normal distribution given the parameters + evaluated at `x`. + + Args: + x (Number or Tensor): The point at which to evaluate the Normal distribution p.d.f. + mu (Number or Tensor): The mean of the Normal distribution p.d.f. + sigma(Number or Tensor): The standard deviation of the Normal distribution p.d.f. + + Returns: + MXNet NDArray: Value of N(x|mu, sigma). 
+ """ + x = self.astensor(x) + mu = self.astensor(mu) + sigma = self.astensor(sigma) + + # This is currently copied directly from PyTorch's source until a better + # way can be found to do this in MXNet + # https://github.com/pytorch/pytorch/blob/master/torch/distributions/normal.py#L61-L66 + def log_prob(value, loc, scale): + variance = scale ** 2 + log_scale = math.log(scale) if isinstance( + scale, Number) else scale.log() + return -((value - loc) ** 2) / (2 * variance) - log_scale - math.log(math.sqrt(2 * math.pi)) + return self.exp(log_prob(x, mu, sigma)) diff --git a/setup.py b/setup.py index 6d7b17ea38..6fc0f71add 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,9 @@ 'torch': [ 'torch' ], + 'mxnet':[ + 'mxnet', + ], 'develop': [ 'pyflakes', 'pytest>=3.2.0', @@ -29,7 +32,9 @@ 'uproot', 'papermill', 'torch', - 'tensorflow' + 'tensorflow', + 'mxnet>=1.0.0', + 'graphviz' ] }, entry_points = { diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 29e523b9b3..99f7c95b45 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1,33 +1,39 @@ from pyhf.tensor.pytorch_backend import pytorch_backend from pyhf.tensor.numpy_backend import numpy_backend from pyhf.tensor.tensorflow_backend import tensorflow_backend +from pyhf.tensor.mxnet_backend import mxnet_backend from pyhf.simplemodels import hepdata_like import tensorflow as tf def test_common_tensor_backends(): tf_sess = tf.Session() - for tb in [numpy_backend(), pytorch_backend(), tensorflow_backend(session = tf_sess)]: - assert tb.tolist(tb.astensor([1,2,3])) == [1,2,3] - assert tb.tolist(tb.ones((2,3))) == [[1,1,1],[1,1,1]] - assert tb.tolist(tb.sum([[1,2,3],[4,5,6]], axis = 0)) == [5,7,9] - assert tb.tolist(tb.product([[1,2,3],[4,5,6]], axis = 0)) == [4,10,18] - assert tb.tolist(tb.power([1,2,3],[1,2,3])) == [1,4,27] - assert tb.tolist(tb.divide([4,9,16],[2,3,4])) == [2,3,4] - assert tb.tolist(tb.outer([1,2,3],[4,5,6])) == [[4,5,6],[8,10,12],[12,15,18]] - assert tb.tolist(tb.sqrt([4,9,16])) == [2,3,4] - assert tb.tolist(tb.stack([tb.astensor([1,2,3]),tb.astensor([4,5,6])])) == [[1,2,3],[4,5,6]] - assert tb.tolist(tb.concatenate([tb.astensor([1,2,3]),tb.astensor([4,5,6])])) == [1,2,3,4,5,6] - assert tb.tolist(tb.log(tb.exp([2,3,4]))) == [2,3,4] + for tb in [numpy_backend(), pytorch_backend(), + tensorflow_backend(session=tf_sess), mxnet_backend()]: + assert tb.tolist(tb.astensor([1, 2, 3])) == [1, 2, 3] + assert tb.tolist(tb.ones((2, 3))) == [[1, 1, 1], [1, 1, 1]] + assert tb.tolist(tb.sum([[1, 2, 3], [4, 5, 6]], axis=0)) == [5, 7, 9] + assert tb.tolist( + tb.product([[1, 2, 3], [4, 5, 6]], axis=0)) == [4, 10, 18] + assert tb.tolist(tb.power([1, 2, 3], [1, 2, 3])) == [1, 4, 27] + assert tb.tolist(tb.divide([4, 9, 16], [2, 3, 4])) == [2, 3, 4] + assert tb.tolist( + tb.outer([1, 2, 3], [4, 5, 6])) == [[4, 5, 6], [8, 10, 12], [12, 15, 18]] + assert tb.tolist(tb.sqrt([4, 9, 16])) == [2, 3, 4] + assert tb.tolist(tb.stack( + [tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6])])) == [[1, 2, 3], [4, 5, 6]] + assert tb.tolist(tb.concatenate( + [tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6])])) == [1, 2, 3, 4, 5, 6] + assert tb.tolist(tb.log(tb.exp([2, 3, 4]))) == [2, 3, 4] assert tb.tolist(tb.where( - tb.astensor([1,0,1]), - tb.astensor([1,1,1]), - tb.astensor([2,2,2]))) == [1,2,1] + tb.astensor([1, 0, 1]), + tb.astensor([1, 1, 1]), + tb.astensor([2, 2, 2]))) == [1, 2, 1] - assert list(map(tb.tolist,tb.simple_broadcast( - tb.astensor([1,1,1]), + assert list(map(tb.tolist, tb.simple_broadcast( + tb.astensor([1, 1, 1]), 
tb.astensor([2]), - tb.astensor([3,3,3])))) == [[1,1,1],[2,2,2],[3,3,3]] + tb.astensor([3, 3, 3])))) == [[1, 1, 1], [2, 2, 2], [3, 3, 3]] def test_pdf_eval(): @@ -36,38 +42,40 @@ def test_pdf_eval(): oldlib = pyhf.tensorlib tf_sess = tf.Session() - backends = [numpy_backend(poisson_from_normal = True), pytorch_backend(), tensorflow_backend(session = tf_sess)] + backends = [numpy_backend(poisson_from_normal=True), + pytorch_backend(), + tensorflow_backend(session=tf_sess), + mxnet_backend()] values = [] for b in backends: - pyhf.tensorlib = b source = { - "binning": [2,-0.5,1.5], - "bindata": { - "data": [120.0, 180.0], - "bkg": [100.0, 150.0], - "bkgsys_up": [102, 190], - "bkgsys_dn": [98, 100], - "sig": [30.0, 95.0] - } + "binning": [2, -0.5, 1.5], + "bindata": { + "data": [120.0, 180.0], + "bkg": [100.0, 150.0], + "bkgsys_up": [102, 190], + "bkgsys_dn": [98, 100], + "sig": [30.0, 95.0] + } } spec = { 'singlechannel': { 'signal': { 'data': source['bindata']['sig'], - 'mods': [{'name': 'mu','type': 'normfactor','data': None}] + 'mods': [{'name': 'mu', 'type': 'normfactor', 'data': None}] }, 'background': { 'data': source['bindata']['bkg'], - 'mods': [{'name': 'bkg_norm','type': 'histosys','data': { + 'mods': [{'name': 'bkg_norm', 'type': 'histosys', 'data': { 'lo_hist': source['bindata']['bkgsys_dn'], 'hi_hist': source['bindata']['bkgsys_up'], }}] } } } - pdf = pyhf.hfpdf(spec) + pdf = pyhf.hfpdf(spec) data = source['bindata']['data'] + pdf.config.auxdata v1 = pdf.logpdf(pdf.config.suggested_init(), data) @@ -84,24 +92,27 @@ def test_pdf_eval_2(): oldlib = pyhf.tensorlib tf_sess = tf.Session() - backends = [numpy_backend(poisson_from_normal = True), pytorch_backend(), tensorflow_backend(session = tf_sess)] + backends = [numpy_backend(poisson_from_normal=True), + pytorch_backend(), + tensorflow_backend(session=tf_sess), + mxnet_backend()] values = [] for b in backends: - pyhf.tensorlib = b source = { - "binning": [2,-0.5,1.5], - "bindata": { - "data": [120.0, 180.0], - "bkg": [100.0, 150.0], - "bkgerr": [10.0, 10.0], - "sig": [30.0, 95.0] - } + "binning": [2, -0.5, 1.5], + "bindata": { + "data": [120.0, 180.0], + "bkg": [100.0, 150.0], + "bkgerr": [10.0, 10.0], + "sig": [30.0, 95.0] + } } - pdf = hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr']) + pdf = hepdata_like(source['bindata']['sig'], source['bindata'][ + 'bkg'], source['bindata']['bkgerr']) data = source['bindata']['data'] + pdf.config.auxdata v1 = pdf.logpdf(pdf.config.suggested_init(), data)
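
Usage note (not part of the diff): a minimal sketch of how the new MXNet backend can be exercised, following the flow of examples/notebooks/example-mxnet.ipynb and the updated tests above. It assumes pyhf with this change and mxnet>=1.0.0 are installed; the quoted result is the value printed in the notebook output and may differ slightly across mxnet versions.

```python
# Minimal sketch mirroring examples/notebooks/example-mxnet.ipynb from this diff.
import pyhf
from pyhf.simplemodels import hepdata_like

source = {
    "binning": [2, -0.5, 1.5],
    "bindata": {
        "data": [120.0, 180.0],
        "bkg": [100.0, 150.0],
        "bkgerr": [10.0, 10.0],
        "sig": [30.0, 95.0]
    }
}

# Build the single-channel model and append the auxiliary data
pdf = hepdata_like(source['bindata']['sig'],
                   source['bindata']['bkg'],
                   source['bindata']['bkgerr'])
data = source['bindata']['data'] + pdf.config.auxdata

# Swap the active tensor library to the MXNet backend introduced in this diff
pyhf.tensorlib = pyhf.mxnet_backend()
v = pdf.logpdf(pdf.config.suggested_init(), data)
print(pyhf.tensorlib.tolist(v))  # notebook reports approximately [-22.87784958]
```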