From faac9ebed5cc7e87b45fc9ff51a077131baaaa92 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Sat, 10 Feb 2018 01:22:20 +0100 Subject: [PATCH 01/10] Add placeholder MXNet backend Add a placeholder file for the MXNet backend and partially fill it in. In addition also add mxnet to packages to pip install and add mxnet_backend to the list of backends to test. --- pyhf/tensor/mxnet_backend.py | 183 +++++++++++++++++++++++++++++++++++ setup.py | 7 +- tests/test_tensor.py | 93 ++++++++++-------- 3 files changed, 241 insertions(+), 42 deletions(-) create mode 100644 pyhf/tensor/mxnet_backend.py diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py new file mode 100644 index 0000000000..36e72e17b9 --- /dev/null +++ b/pyhf/tensor/mxnet_backend.py @@ -0,0 +1,183 @@ +import mxnet as mx +from mxnet import nd +import logging +log = logging.getLogger(__name__) + + +class mxnet_backend(object): + """Backend for MXNet""" + + def __init__(self, **kwargs): + self.session = kwargs.get('session') + + def tolist(self, tensor_in): + """ + Convert a tensor to a list + + Args: + tensor_in: MXNet tensor + + Returns: + The possibly nested list of tensor elements. + """ + tensor_in = self.astensor(tensor_in) + return tensor_in.asnumpy().tolist() + + def outer(self, tensor_in_1, tensor_in_2): + """ + The outer product of two tensors: u.v^T + + Args: + tensor_in_1: tensor object + tensor_in_2: tensor object + + Returns: + MXNet NDArray: The outer product + """ + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + pass + + def astensor(self, tensor_in): + """ + Convert a tensor to an MXNet NDArray + + Args: + tensor_in: tensor object + + Returns: + MXNet NDArray: a multi-dimensional, fixed-size homogenous array. + """ + return nd.array(tensor_in) + + def sum(self, tensor_in, axis=None): + """ + Compute the sum of array elements over given axes. + + Args: + tensor_in: tensor object + axis: the axes over which to sum + + Returns: + MXNet NDArray: ndarray of the sum over the axes + """ + tensor_in = self.astensor(tensor_in) + if axis is None or tensor_in.shape == nd.Size([]): + return nd.sum(tensor_in) + else: + return nd.sum(tensor_in, axis) + + def product(self, tensor_in, axis=None): + pass + + def ones(self, shape): + """ + A new array filled with all ones, with the given shape. + + Args: + shape: the shape of the array + + Returns: + MXNet NDArray: ndarray of 1's with given shape + """ + return nd.ones(shape) + + def zeros(self, shape): + """ + A new array filled with all zeros, with the given shape. + + Args: + shape: the shape of the array + + Returns: + MXNet NDArray: ndarray of 0's with given shape + """ + return nd.zeros(shape) + + def power(self, tensor_in_1, tensor_in_2): + """ + Result of first array elements raised to powers from second array, + element-wise with broadcasting. + + Args: + tensor_in_1: tensor object + tensor_in_2: tensor object + + Returns: + MXNet NDArray: first array elements raised to powers from second array + """ + tensor_in_1 = self.astensor(tensor_in_1) + tensor_in_2 = self.astensor(tensor_in_2) + return nd.power(tensor_in_1, tensor_in_2) + + def sqrt(self, tensor_in): + """ + Element-wise square-root value of the input. + + Args: + tensor_in: tensor object + + Returns: + MXNet NDArray: element-wise square-root value + """ + tensor_in = self.astensor(tensor_in) + return nd.sqrt(tensor_in) + + def divide(self, tensor_in_1, tensor_in_2): + """ + Element-wise division of the input arrays with broadcasting. 
+
+        Args:
+            tensor_in_1: tensor object
+            tensor_in_2: tensor object
+
+        Returns:
+            MXNet NDArray: element-wise division of the input arrays
+        """
+        tensor_in_1 = self.astensor(tensor_in_1)
+        tensor_in_2 = self.astensor(tensor_in_2)
+        return nd.divide(tensor_in_1, tensor_in_2)
+
+    def log(self, tensor_in):
+        """
+        Element-wise natural logarithmic value of the input.
+
+        Args:
+            tensor_in: tensor object
+
+        Returns:
+            MXNet NDArray: element-wise natural logarithmic value
+        """
+        tensor_in = self.astensor(tensor_in)
+        return nd.log(tensor_in)
+
+    def exp(self, tensor_in):
+        """
+        Element-wise exponential value of the input.
+
+        Args:
+            tensor_in: tensor object
+
+        Returns:
+            MXNet NDArray: element-wise exponential value
+        """
+        tensor_in = self.astensor(tensor_in)
+        return nd.exp(tensor_in)
+
+    def stack(self, sequence, axis=0):
+        pass
+
+    def where(self, mask, tensor_in_1, tensor_in_2):
+        pass
+
+    def concatenate(self, sequence):
+        pass
+
+    def simple_broadcast(self, *args):
+        pass
+
+    def poisson(self, n, lam):
+        pass
+
+    def normal(self, x, mu, sigma):
+        pass
diff --git a/setup.py b/setup.py
index 6d7b17ea38..6fc0f71add 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,9 @@
         'torch': [
             'torch'
         ],
+        'mxnet': [
+            'mxnet',
+        ],
         'develop': [
             'pyflakes',
             'pytest>=3.2.0',
@@ -29,7 +32,9 @@
             'uproot',
             'papermill',
             'torch',
-            'tensorflow'
+            'tensorflow',
+            'mxnet>=1.0.0',
+            'graphviz'
         ]
     },
     entry_points = {
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 29e523b9b3..99f7c95b45 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -1,33 +1,39 @@
 from pyhf.tensor.pytorch_backend import pytorch_backend
 from pyhf.tensor.numpy_backend import numpy_backend
 from pyhf.tensor.tensorflow_backend import tensorflow_backend
+from pyhf.tensor.mxnet_backend import mxnet_backend
 from pyhf.simplemodels import hepdata_like
 import tensorflow as tf
 
 
 def test_common_tensor_backends():
     tf_sess = tf.Session()
-    for tb in [numpy_backend(), pytorch_backend(), tensorflow_backend(session = tf_sess)]:
-        assert tb.tolist(tb.astensor([1,2,3])) == [1,2,3]
-        assert tb.tolist(tb.ones((2,3))) == [[1,1,1],[1,1,1]]
-        assert tb.tolist(tb.sum([[1,2,3],[4,5,6]], axis = 0)) == [5,7,9]
-        assert tb.tolist(tb.product([[1,2,3],[4,5,6]], axis = 0)) == [4,10,18]
-        assert tb.tolist(tb.power([1,2,3],[1,2,3])) == [1,4,27]
-        assert tb.tolist(tb.divide([4,9,16],[2,3,4])) == [2,3,4]
-        assert tb.tolist(tb.outer([1,2,3],[4,5,6])) == [[4,5,6],[8,10,12],[12,15,18]]
-        assert tb.tolist(tb.sqrt([4,9,16])) == [2,3,4]
-        assert tb.tolist(tb.stack([tb.astensor([1,2,3]),tb.astensor([4,5,6])])) == [[1,2,3],[4,5,6]]
-        assert tb.tolist(tb.concatenate([tb.astensor([1,2,3]),tb.astensor([4,5,6])])) == [1,2,3,4,5,6]
-        assert tb.tolist(tb.log(tb.exp([2,3,4]))) == [2,3,4]
+    for tb in [numpy_backend(), pytorch_backend(),
+               tensorflow_backend(session=tf_sess), mxnet_backend()]:
+        assert tb.tolist(tb.astensor([1, 2, 3])) == [1, 2, 3]
+        assert tb.tolist(tb.ones((2, 3))) == [[1, 1, 1], [1, 1, 1]]
+        assert tb.tolist(tb.sum([[1, 2, 3], [4, 5, 6]], axis=0)) == [5, 7, 9]
+        assert tb.tolist(
+            tb.product([[1, 2, 3], [4, 5, 6]], axis=0)) == [4, 10, 18]
+        assert tb.tolist(tb.power([1, 2, 3], [1, 2, 3])) == [1, 4, 27]
+        assert tb.tolist(tb.divide([4, 9, 16], [2, 3, 4])) == [2, 3, 4]
+        assert tb.tolist(
+            tb.outer([1, 2, 3], [4, 5, 6])) == [[4, 5, 6], [8, 10, 12], [12, 15, 18]]
+        assert tb.tolist(tb.sqrt([4, 9, 16])) == [2, 3, 4]
+        assert tb.tolist(tb.stack(
+            [tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6])])) == [[1, 2, 3], [4, 5, 6]]
+        assert tb.tolist(tb.concatenate(
+            [tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6])])) == [1, 2, 3, 4, 5, 6]
+        assert tb.tolist(tb.log(tb.exp([2, 3, 4]))) == [2, 3, 4]
         assert tb.tolist(tb.where(
-            tb.astensor([1,0,1]),
-            tb.astensor([1,1,1]),
-            tb.astensor([2,2,2]))) == [1,2,1]
+            tb.astensor([1, 0, 1]),
+            tb.astensor([1, 1, 1]),
+            tb.astensor([2, 2, 2]))) == [1, 2, 1]
 
-        assert list(map(tb.tolist,tb.simple_broadcast(
-            tb.astensor([1,1,1]),
+        assert list(map(tb.tolist, tb.simple_broadcast(
+            tb.astensor([1, 1, 1]),
             tb.astensor([2]),
-            tb.astensor([3,3,3])))) == [[1,1,1],[2,2,2],[3,3,3]]
+            tb.astensor([3, 3, 3])))) == [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
 
 
 def test_pdf_eval():
@@ -36,38 +42,40 @@
     oldlib = pyhf.tensorlib
 
     tf_sess = tf.Session()
-    backends = [numpy_backend(poisson_from_normal = True), pytorch_backend(), tensorflow_backend(session = tf_sess)]
+    backends = [numpy_backend(poisson_from_normal=True),
+                pytorch_backend(),
+                tensorflow_backend(session=tf_sess),
+                mxnet_backend()]
 
     values = []
     for b in backends:
         pyhf.tensorlib = b
         source = {
-          "binning": [2,-0.5,1.5],
-          "bindata": {
-            "data": [120.0, 180.0],
-            "bkg": [100.0, 150.0],
-            "bkgsys_up": [102, 190],
-            "bkgsys_dn": [98, 100],
-            "sig": [30.0, 95.0]
-          }
+            "binning": [2, -0.5, 1.5],
+            "bindata": {
+                "data": [120.0, 180.0],
+                "bkg": [100.0, 150.0],
+                "bkgsys_up": [102, 190],
+                "bkgsys_dn": [98, 100],
+                "sig": [30.0, 95.0]
+            }
         }
         spec = {
             'singlechannel': {
                 'signal': {
                     'data': source['bindata']['sig'],
-                    'mods': [{'name': 'mu','type': 'normfactor','data': None}]
+                    'mods': [{'name': 'mu', 'type': 'normfactor', 'data': None}]
                 },
                 'background': {
                     'data': source['bindata']['bkg'],
-                    'mods': [{'name': 'bkg_norm','type': 'histosys','data': {
+                    'mods': [{'name': 'bkg_norm', 'type': 'histosys', 'data': {
                         'lo_hist': source['bindata']['bkgsys_dn'],
                         'hi_hist': source['bindata']['bkgsys_up'],
                     }}]
                 }
             }
         }
-        pdf  = pyhf.hfpdf(spec)
+        pdf = pyhf.hfpdf(spec)
 
         data = source['bindata']['data'] + pdf.config.auxdata
 
         v1 = pdf.logpdf(pdf.config.suggested_init(), data)
@@ -84,24 +92,27 @@ def test_pdf_eval_2():
     oldlib = pyhf.tensorlib
 
     tf_sess = tf.Session()
-    backends = [numpy_backend(poisson_from_normal = True), pytorch_backend(), tensorflow_backend(session = tf_sess)]
+    backends = [numpy_backend(poisson_from_normal=True),
+                pytorch_backend(),
+                tensorflow_backend(session=tf_sess),
+                mxnet_backend()]
 
     values = []
     for b in backends:
         pyhf.tensorlib = b
         source = {
-          "binning": [2,-0.5,1.5],
-          "bindata": {
-            "data": [120.0, 180.0],
-            "bkg": [100.0, 150.0],
-            "bkgerr": [10.0, 10.0],
-            "sig": [30.0, 95.0]
-          }
+            "binning": [2, -0.5, 1.5],
+            "bindata": {
+                "data": [120.0, 180.0],
+                "bkg": [100.0, 150.0],
+                "bkgerr": [10.0, 10.0],
+                "sig": [30.0, 95.0]
+            }
         }
-        pdf = hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])
+        pdf = hepdata_like(source['bindata']['sig'], source['bindata'][
+                           'bkg'], source['bindata']['bkgerr'])
 
         data = source['bindata']['data'] + pdf.config.auxdata
 
         v1 = pdf.logpdf(pdf.config.suggested_init(), data)

From 0e8b5e41a755ead2910fe412daf0a7c5b4d4d1ab Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Sun, 11 Feb 2018 14:13:39 +0100
Subject: [PATCH 02/10] Add MXNet backend stack() and concatenate()

Other methods are added as well, but these two are the most important.
---
 pyhf/tensor/mxnet_backend.py | 98 ++++++++++++++++++++++++++++++++----
 1 file changed, 89 insertions(+), 9 deletions(-)

diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index 36e72e17b9..9696b01724 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -8,7 +8,7 @@ class mxnet_backend(object):
     """Backend for MXNet"""
 
     def __init__(self, **kwargs):
-        self.session = kwargs.get('session')
+        pass
 
     def tolist(self, tensor_in):
         """
@@ -62,13 +62,27 @@ def sum(self, tensor_in, axis=None):
             MXNet NDArray: ndarray of the sum over the axes
         """
         tensor_in = self.astensor(tensor_in)
-        if axis is None or tensor_in.shape == nd.Size([]):
+        if axis is None or tensor_in.shape == nd.array([]).size:
             return nd.sum(tensor_in)
         else:
             return nd.sum(tensor_in, axis)
 
     def product(self, tensor_in, axis=None):
-        pass
+        """
+        Product of array elements over given axes.
+
+        Args:
+            tensor_in: tensor object
+            axis: the axes over which to take the product
+
+        Returns:
+            MXNet NDArray: ndarray of the product over the axes
+        """
+        tensor_in = self.astensor(tensor_in)
+        if axis is None:
+            return nd.prod(tensor_in)
+        else:
+            return nd.prod(tensor_in, axis)
 
     def ones(self, shape):
         """
@@ -165,19 +179,85 @@ def exp(self, tensor_in):
         return nd.exp(tensor_in)
 
     def stack(self, sequence, axis=0):
-        pass
+        """
+        Join a sequence of arrays along a new axis.
+
+        The axis parameter specifies the index of the new axis in the dimensions
+        of the result. For example, if axis=0 it will be the first dimension and
+        if axis=-1 it will be the last dimension.
+
+        Args:
+            sequence: sequence of arrays
+            axis: the axis along which to join the arrays
+
+        Returns:
+            MXNet NDArray: ndarray comprised of the elements of the sequence
+        """
+        return nd.stack(*sequence, axis=axis)
 
     def where(self, mask, tensor_in_1, tensor_in_2):
-        pass
+        """
+        Apply a boolean selection mask to the elements of the input tensors
+
+        Example:
+            >>> mxnet_backend.where(
+                  mxnet_backend.astensor([1, 0, 1]),
+                  mxnet_backend.astensor([1, 1, 1]),
+                  mxnet_backend.astensor([2, 2, 2]))
+            >>> [1. 2. 1.]
+
+        Args:
+            mask: Boolean mask (boolean or tensor object of booleans)
+            tensor_in_1: tensor object
+            tensor_in_2: tensor object
+
+        Returns:
+            MXNet NDArray: The result of the mask being applied to the tensors
+        """
+        mask = self.astensor(mask)
+        tensor_in_1 = self.astensor(tensor_in_1)
+        tensor_in_2 = self.astensor(tensor_in_2)
+        return nd.multiply(mask, tensor_in_1) + nd.multiply(nd.subtract(1, mask), tensor_in_2)
 
     def concatenate(self, sequence):
-        pass
+        """
+        Join the elements of the sequence
+
+        Args:
+            sequence: the sequence of arrays to join
+
+        Returns:
+            MXNet NDArray: The ndarray of the joined elements
+        """
+        return nd.concat(*sequence, dim=0)
 
     def simple_broadcast(self, *args):
-        pass
+        """
+        Does this work?
+        There should be a more MXNet-style way to do this
+        """
+        broadcast = []
+        max_dim = max(map(len, args))
+        for arg in args:
+            broadcast.append(self.astensor(arg)
+                             if len(arg) > 1 else arg * self.ones(max_dim))
+        return broadcast
 
     def poisson(self, n, lam):
-        pass
+        return self.normal(n, lam, self.sqrt(lam))
 
     def normal(self, x, mu, sigma):
-        pass
+        import math
+        from numbers import Number
+        x = self.astensor(x)
+        mu = self.astensor(mu)
+        sigma = self.astensor(sigma)
+        # Is needed?
+        # normal = nd.random.normal(loc=mu, scale=sigma)
+
+        def log_prob(value, loc, scale):
+            variance = scale ** 2
+            log_scale = math.log(scale) if isinstance(
+                scale, Number) else scale.log()
+            return -((value - loc) ** 2) / (2 * variance) - log_scale - math.log(math.sqrt(2 * math.pi))
+        return self.exp(log_prob(x, mu, sigma))

From f540dde1ed060a3a6b8fe8eb242aa03e5c702dbd Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Sun, 11 Feb 2018 17:51:52 +0100
Subject: [PATCH 03/10] Add placeholder for outer()

The current outer() function is highly suboptimal. However, until a
better solution can be implemented in a more MXNet-centric style, it
will at least pass the tests.

The Pull Request in which this commit exists should _NOT_ be merged in
until outer() is updated with a better solution.
---
 pyhf/tensor/mxnet_backend.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index 9696b01724..0f12adee81 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -1,6 +1,6 @@
-import mxnet as mx
 from mxnet import nd
 import logging
+import itertools  # Hack fix for mxnet_backend.outer()
 log = logging.getLogger(__name__)
@@ -34,9 +34,16 @@ def outer(self, tensor_in_1, tensor_in_2):
         Returns:
             MXNet NDArray: The outer product
         """
+        # This is currently a rather stupid way to do things, so need to fix this
+        # Currently also is assuming only 1-d tensors, which is bad
         tensor_in_1 = self.astensor(tensor_in_1)
         tensor_in_2 = self.astensor(tensor_in_2)
-        pass
+        tensor_in_2 = tensor_in_2.T
+        outer = nd.ones((tensor_in_1.size, tensor_in_2.size))
+        for i, j in itertools.product(range(tensor_in_1.size),
+                                      range(tensor_in_2.size)):
+            outer[i, j] = nd.dot(tensor_in_1[i], tensor_in_2[j])
+        return outer
 
     def astensor(self, tensor_in):
         """
@@ -233,7 +240,6 @@ def concatenate(self, sequence):
 
     def simple_broadcast(self, *args):
         """
-        Does this work?
         There should be a more MXNet-style way to do this
         """
         broadcast = []
@@ -247,13 +253,14 @@ def poisson(self, n, lam):
         return self.normal(n, lam, self.sqrt(lam))
 
     def normal(self, x, mu, sigma):
+        """
+        Currently copying from PyTorch's source until can find a better way to do this
+        """
        import math
        from numbers import Number
         x = self.astensor(x)
         mu = self.astensor(mu)
         sigma = self.astensor(sigma)
-        # Is needed?
-        # normal = nd.random.normal(loc=mu, scale=sigma)
 
         def log_prob(value, loc, scale):
             variance = scale ** 2

From bf92334979562c221bf033a0afea7b0533d79b1c Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Mon, 12 Feb 2018 01:00:26 +0100
Subject: [PATCH 04/10] Use MXNet-centric implementation of outer()

Use the linear algebra functions of MXNet to provide a cleaner
implementation of the outer product method, outer()
---
 pyhf/tensor/mxnet_backend.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index 0f12adee81..398ba7bee1 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -1,6 +1,5 @@
 from mxnet import nd
 import logging
-import itertools  # Hack fix for mxnet_backend.outer()
 log = logging.getLogger(__name__)
@@ -25,7 +24,7 @@ def tolist(self, tensor_in):
 
     def outer(self, tensor_in_1, tensor_in_2):
         """
-        The outer product of two tensors: u.v^T
+        The outer product of two tensors
 
         Args:
             tensor_in_1: tensor object
@@ -34,16 +33,21 @@ def outer(self, tensor_in_1, tensor_in_2):
         Returns:
             MXNet NDArray: The outer product
         """
-        # This is currently a rather stupid way to do things, so need to fix this
-        # Currently also is assuming only 1-d tensors, which is bad
         tensor_in_1 = self.astensor(tensor_in_1)
         tensor_in_2 = self.astensor(tensor_in_2)
-        tensor_in_2 = tensor_in_2.T
-        outer = nd.ones((tensor_in_1.size, tensor_in_2.size))
-        for i, j in itertools.product(range(tensor_in_1.size),
-                                      range(tensor_in_2.size)):
-            outer[i, j] = nd.dot(tensor_in_1[i], tensor_in_2[j])
-        return outer
+
+        tensor_1_shape = tensor_in_1.shape
+        tensor_2_shape = tensor_in_2.shape
+        if len(tensor_1_shape) == 1:
+            tensor_1_shape = (*tensor_1_shape, 1)
+        if len(tensor_2_shape) == 1:
+            tensor_2_shape = (*tensor_2_shape, 1)
+
+        rows1, cols1 = tensor_1_shape
+        rows2, cols2 = tensor_2_shape
+        return nd.reshape(nd.dot(tensor_in_1.reshape((rows1, 1, cols1, 1)),
+                                 tensor_in_2.reshape((1, rows2, 1, cols2))),
+                          (rows1 * cols1, rows2 * cols2))
 
     def astensor(self, tensor_in):

From 9fd4d013a23501272d6b9ed636ccee1daf3382d4 Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Mon, 12 Feb 2018 02:07:12 +0100
Subject: [PATCH 05/10] Add mxnet_backend try block
---
 pyhf/__init__.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/pyhf/__init__.py b/pyhf/__init__.py
index 7286b8ab68..4016a18747 100644
--- a/pyhf/__init__.py
+++ b/pyhf/__init__.py
@@ -18,6 +18,12 @@
 except ImportError:
     pass
 
+try:
+    from .tensor.mxnet_backend import mxnet_backend
+    assert mxnet_backend
+except ImportError:
+    pass
+
 tensorlib = numpy_backend()
 optimizer = scipy_optimizer()

From e6326dd326d7c8415078cf06129e463060b904b0 Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Mon, 12 Feb 2018 02:14:11 +0100
Subject: [PATCH 06/10] Add default Jupyter notebook example using MXNet backend
---
 examples/notebooks/example-mxnet.ipynb | 96 ++++++++++++++++++++++++++
 1 file changed, 96 insertions(+)
 create mode 100644 examples/notebooks/example-mxnet.ipynb

diff --git a/examples/notebooks/example-mxnet.ipynb b/examples/notebooks/example-mxnet.ipynb
new file mode 100644
index 0000000000..383245a6c3
--- /dev/null
+++ b/examples/notebooks/example-mxnet.ipynb
@@ -0,0 +1,96 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Populating the interactive namespace from numpy and matplotlib\n"
+     ]
+    }
+   ],
+   "source": [
+    "%pylab inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pyhf\n",
+    "from pyhf import hfpdf\n",
+    "from pyhf.simplemodels import hepdata_like\n",
+    "import mxnet as mx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "--------\n",
+      "as MXNet\n",
+      "--------\n",
+      "<class 'mxnet.ndarray.ndarray.NDArray'> \n",
+      "[-22.87784958]\n",
+      "<NDArray 1 @cpu(0)>\n"
+     ]
+    }
+   ],
+   "source": [
+    "source = {\n",
+    "    \"binning\": [2,-0.5,1.5],\n",
+    "    \"bindata\": {\n",
+    "        \"data\": [120.0, 180.0],\n",
+    "        \"bkg\": [100.0, 150.0],\n",
+    "        \"bkgerr\": [10.0, 10.0],\n",
+    "        \"sig\": [30.0, 95.0]\n",
+    "    }\n",
+    "}\n",
+    "\n",
+    "pdf = hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])\n",
+    "data = source['bindata']['data'] + pdf.config.auxdata\n",
+    "\n",
+    "init_pars = pdf.config.suggested_init()\n",
+    "par_bounds = pdf.config.suggested_bounds()\n",
+    "\n",
+    "\n",
+    "print('--------\\nas MXNet\\n--------')\n",
+    "pyhf.tensorlib = pyhf.mxnet_backend()\n",
+    "v = pdf.logpdf(init_pars, data)\n",
+    "print(type(v),v)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

From 03b3f77aa12e09636d26c6eb5db7e8a2de9734f2 Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Mon, 12 Feb 2018 03:08:55 +0100
Subject: [PATCH 07/10] Use list indexing instead of extended iterable unpacking

Extended iterable unpacking wasn't added until PEP 3132 and so isn't
available in Python 2, which is causing the Python 2.7 tests to fail.
While this is annoying, it is probably worth still supporting Python 2.7
until the official Python 2 end-of-life date.

c.f. https://www.python.org/dev/peps/pep-3132/
---
 pyhf/tensor/mxnet_backend.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index 398ba7bee1..0b0d55cdcf 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -39,9 +39,9 @@ def outer(self, tensor_in_1, tensor_in_2):
         tensor_1_shape = tensor_in_1.shape
         tensor_2_shape = tensor_in_2.shape
         if len(tensor_1_shape) == 1:
-            tensor_1_shape = (*tensor_1_shape, 1)
+            tensor_1_shape = (tensor_1_shape[0], 1)
         if len(tensor_2_shape) == 1:
-            tensor_2_shape = (*tensor_2_shape, 1)
+            tensor_2_shape = (tensor_2_shape[0], 1)
 
         rows1, cols1 = tensor_1_shape
         rows2, cols2 = tensor_2_shape

From 6fc54365f73a972e135ec3cdb75cf8405e9890e3 Mon Sep 17 00:00:00 2001
From: Matthew Feickert
Date: Mon, 12 Feb 2018 11:48:05 +0100
Subject: [PATCH 08/10] Use Conda environment.yml for Binder

To support the more complex system of multiple backends, a Conda
environment.yml file is used instead of a pip requirements.txt file for
use with Binder.

c.f.
https://mybinder.readthedocs.io/en/latest/sample_repos.html#conda-environment-with-environment-yml --- binder/environment.yml | 111 ++++++++++++++++++++++++++++++++++++++++ binder/requirements.txt | 6 --- 2 files changed, 111 insertions(+), 6 deletions(-) create mode 100644 binder/environment.yml delete mode 100644 binder/requirements.txt diff --git a/binder/environment.yml b/binder/environment.yml new file mode 100644 index 0000000000..7ba02c3af7 --- /dev/null +++ b/binder/environment.yml @@ -0,0 +1,111 @@ +name: pyhf +channels: +- pytorch +- https://conda.anaconda.org/NLeSC +- defaults +dependencies: +- ca-certificates=2017.08.26=h1d4fec5_0 +- certifi=2018.1.18=py36_0 +- cffi=1.11.4=py36h9745a5d_0 +- cudatoolkit=8.0=3 +- cudnn=7.0.5=cuda8.0_0 +- freetype=2.8=hab7d2ae_1 +- intel-openmp=2018.0.0=hc7b2577_8 +- jpeg=9b=h024ee3a_2 +- libedit=3.1=heed3624_0 +- libffi=3.2.1=hd88cf55_4 +- libgcc-ng=7.2.0=h7cc24e2_2 +- libgfortran-ng=7.2.0=h9f7466a_2 +- libpng=1.6.34=hb9fc6fc_0 +- libstdcxx-ng=7.2.0=h7a57d05_2 +- libtiff=4.0.9=h28f6b97_0 +- mkl=2018.0.1=h19d6760_4 +- nccl=1.3.4=cuda8.0_1 +- ncurses=6.0=h9df7e31_2 +- numpy=1.14.0=py36h3dfced4_1 +- olefile=0.45.1=py36_0 +- openssl=1.0.2n=hb7f436b_0 +- pillow=5.0.0=py36h3deb7b8_0 +- pip=9.0.1=py36h6c6f9ce_4 +- pycparser=2.18=py36hf9f622e_1 +- python=3.6.4=hc3d631a_1 +- pytorch=0.3.0=py36cuda8.0cudnn7.0_0 +- readline=7.0=ha6073c6_4 +- setuptools=38.4.0=py36_0 +- six=1.11.0=py36h372c433_1 +- sqlite=3.22.0=h1bed415_0 +- tk=8.6.7=hc745277_3 +- torchvision=0.2.0=py36_0 +- wheel=0.30.0=py36hfd4bba0_1 +- xz=5.2.3=h55aa19d_2 +- zlib=1.2.11=ha838bed_2 +- cuda90=1.0=h6433d27_0 +- pip: + - absl-py>=0.1.10 + - ansiwrap>=0.8.3 + - attrs>=17.4.0 + - bleach>=1.5.0 + - boto3>=1.5.26 + - botocore>=1.8.40 + - chardet>=3.0.4 + - click>=6.7 + - coverage>=4.5.1 + - cycler>=0.10.0 + - decorator>=4.2.1 + - docutils>=0.14 + - entrypoints>=0.2.3 + - future>=0.16.0 + - graphviz>=0.8.1 + - html5lib>=0.9999999 + - idna>=2.6 + - ipython>=6.2.1 + - ipython-genutils>=0.2.0 + - jedi>=0.11.1 + - jinja2>=2.10 + - jmespath>=0.9.3 + - jsonschema>=2.6.0 + - jupyter-client>=5.2.2 + - jupyter-core>=4.4.0 + - markdown>=2.6.11 + - markupsafe>=1.0 + - matplotlib>=2.1.2 + - mistune>=0.8.3 + - mxnet>=1.0.0.post4 + - nbconvert>=5.3.1 + - nbformat>=4.4.0 + - pandas>=0.22.0 + - pandocfilters>=1.4.2 + - papermill>=0.12.2 + - parso>=0.1.1 + - pexpect>=4.4.0 + - pickleshare>=0.7.4 + - pluggy>=0.6.0 + - prompt-toolkit>=1.0.15 + - protobuf>=3.5.1 + - ptyprocess>=0.5.2 + - py>=1.5.2 + - pygments>=2.2.0 + - pyhf>=0.0.4 + - pyparsing>=2.2.0 + - pytest>=3.4.0 + - pytest-cov>=2.5.1 + - python-dateutil>=2.6.1 + - pytz>=2018.3 + - pyyaml>=3.12 + - pyzmq>=17.0.0 + - requests>=2.18.4 + - s3transfer>=0.1.12 + - scipy>=1.0.0 + - simplegeneric>=0.8.1 + - tensorflow>=1.5.0 + - tensorflow-tensorboard>=1.5.1 + - testpath>=0.3.1 + - textwrap3>=0.9.1 + - torch>=0.3.0 + - tornado>=4.5.3 + - tqdm>=4.19.5 + - traitlets>=4.3.2 + - uproot>=2.6.10 + - urllib3>=1.22 + - wcwidth>=0.1.7 + - werkzeug>=0.14.1 diff --git a/binder/requirements.txt b/binder/requirements.txt deleted file mode 100644 index 9079a3c250..0000000000 --- a/binder/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -matplotlib -numpy -papermill -pyhf -scipy -uproot From c0c76541b557340612564ced5dc8c077d7f2805e Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Mon, 12 Feb 2018 19:25:21 +0100 Subject: [PATCH 09/10] Have simple_broadcast() return a MXNet NDArray Additionally make a correction to outer(). 
nd.broadcast_mul() is the correct function to use in this method. --- pyhf/tensor/mxnet_backend.py | 49 +++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py index 0b0d55cdcf..92c244aae5 100644 --- a/pyhf/tensor/mxnet_backend.py +++ b/pyhf/tensor/mxnet_backend.py @@ -1,5 +1,7 @@ from mxnet import nd import logging +import math # Required for normal() +from numbers import Number # Required for normal() log = logging.getLogger(__name__) @@ -45,8 +47,8 @@ def outer(self, tensor_in_1, tensor_in_2): rows1, cols1 = tensor_1_shape rows2, cols2 = tensor_2_shape - return nd.reshape(nd.dot(tensor_in_1.reshape((rows1, 1, cols1, 1)), - tensor_in_2.reshape((1, rows2, 1, cols2))), + return nd.reshape(nd.broadcast_mul(tensor_in_1.reshape((rows1, 1, cols1, 1)), + tensor_in_2.reshape((1, rows2, 1, cols2))), (rows1 * cols1, rows2 * cols2)) def astensor(self, tensor_in): @@ -211,11 +213,11 @@ def where(self, mask, tensor_in_1, tensor_in_2): Apply a boolean selection mask to the elements of the input tensors Example: - >>> mxnet_backend.where( - mxnet_backend.astensor([1, 0, 1]), - mxnet_backend.astensor([1, 1, 1]), - mxnet_backend.astensor([2, 2, 2])) - >>> [1. 2. 1.] + >>> where( + astensor([1, 0, 1]), + astensor([1, 1, 1]), + astensor([2, 2, 2])) + [1. 2. 1.] Args: mask: Boolean mask (boolean or tensor object of booleans) @@ -228,7 +230,8 @@ def where(self, mask, tensor_in_1, tensor_in_2): mask = self.astensor(mask) tensor_in_1 = self.astensor(tensor_in_1) tensor_in_2 = self.astensor(tensor_in_2) - return nd.multiply(mask, tensor_in_1) + nd.multiply(nd.subtract(1, mask), tensor_in_2) + return nd.add(nd.multiply(mask, tensor_in_1), + nd.multiply(nd.subtract(1, mask), tensor_in_2)) def concatenate(self, sequence): """ @@ -244,14 +247,32 @@ def concatenate(self, sequence): def simple_broadcast(self, *args): """ - There should be a more MXNet-style way to do this + Broadcast a sequence of 1 dimensional arrays + + Example: + >>> simple_broadcast( + astensor([1]), + astensor([2, 2]), + astensor([3, 3, 3])) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 3.]] + + Args: + args: sequence of arrays + + Returns: + MXNet NDArray: The sequence broadcast together """ - broadcast = [] max_dim = max(map(len, args)) + broadcast = [] for arg in args: - broadcast.append(self.astensor(arg) - if len(arg) > 1 else arg * self.ones(max_dim)) - return broadcast + if len(arg) < max_dim: + broadcast.append(nd.broadcast_axis( + arg[0], axis=len(arg.shape) - 1, size=max_dim)) + else: + broadcast.append(arg) + return nd.stack(*broadcast) def poisson(self, n, lam): return self.normal(n, lam, self.sqrt(lam)) @@ -260,8 +281,6 @@ def normal(self, x, mu, sigma): """ Currently copying from PyTorch's source until can find a better way to do this """ - import math - from numbers import Number x = self.astensor(x) mu = self.astensor(mu) sigma = self.astensor(sigma) From e40a007712d2bc061da9d1ddfe31d2532cdcdba3 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Mon, 12 Feb 2018 22:06:57 +0100 Subject: [PATCH 10/10] Add docstrings Add Google Style docstrings that should be compatible with Sphinx under certain themes, such as the Sphinx theme created by Read The Docs. For an example of what this would look like, see PyTorch's docs. 
c.f.:
http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google
http://pytorch.org/docs/master/_modules/torch/distributions/bernoulli.html#Bernoulli
---
 pyhf/tensor/mxnet_backend.py | 127 +++++++++++++++++++++--------------
 1 file changed, 77 insertions(+), 50 deletions(-)

diff --git a/pyhf/tensor/mxnet_backend.py b/pyhf/tensor/mxnet_backend.py
index 92c244aae5..21aa72721b 100644
--- a/pyhf/tensor/mxnet_backend.py
+++ b/pyhf/tensor/mxnet_backend.py
@@ -6,34 +6,34 @@
 
 
 class mxnet_backend(object):
-    """Backend for MXNet"""
+    """MXNet backend for pyhf"""
 
     def __init__(self, **kwargs):
         pass
 
     def tolist(self, tensor_in):
         """
-        Convert a tensor to a list
+        Convert a tensor to a list.
 
         Args:
-            tensor_in: MXNet tensor
+            tensor_in (Tensor): Input MXNet tensor
 
         Returns:
-            The possibly nested list of tensor elements.
+            list: The possibly nested list of tensor elements.
         """
         tensor_in = self.astensor(tensor_in)
         return tensor_in.asnumpy().tolist()
 
     def outer(self, tensor_in_1, tensor_in_2):
         """
-        The outer product of two tensors
+        The outer product of two tensors.
 
         Args:
-            tensor_in_1: tensor object
-            tensor_in_2: tensor object
+            tensor_in_1 (Tensor): Tensor object
+            tensor_in_2 (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: The outer product
+            MXNet NDArray: The outer product.
         """
         tensor_in_1 = self.astensor(tensor_in_1)
         tensor_in_2 = self.astensor(tensor_in_2)
@@ -53,13 +53,13 @@ def outer(self, tensor_in_1, tensor_in_2):
 
     def astensor(self, tensor_in):
         """
-        Convert a tensor to an MXNet NDArray
+        Convert a tensor to an MXNet NDArray.
 
         Args:
-            tensor_in: tensor object
+            tensor_in (Number or Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: a multi-dimensional, fixed-size homogenous array.
+            MXNet NDArray: A multi-dimensional, fixed-size homogeneous array.
         """
         return nd.array(tensor_in)
 
@@ -68,11 +68,11 @@ def sum(self, tensor_in, axis=None):
         Compute the sum of array elements over given axes.
 
         Args:
-            tensor_in: tensor object
-            axis: the axes over which to sum
+            tensor_in (Tensor): Tensor object
+            axis (Number): The axes over which to sum
 
         Returns:
-            MXNet NDArray: ndarray of the sum over the axes
+            MXNet NDArray: ndarray of the sum over the axes.
         """
         tensor_in = self.astensor(tensor_in)
         if axis is None or tensor_in.shape == nd.array([]).size:
@@ -85,11 +85,11 @@ def product(self, tensor_in, axis=None):
         Product of array elements over given axes.
 
         Args:
-            tensor_in: tensor object
-            axis: the axes over which to take the product
+            tensor_in (Tensor): Tensor object
+            axis (Number): The axes over which to take the product
 
         Returns:
-            MXNet NDArray: ndarray of the product over the axes
+            MXNet NDArray: ndarray of the product over the axes.
         """
         tensor_in = self.astensor(tensor_in)
         if axis is None:
@@ -102,10 +102,10 @@ def ones(self, shape):
         A new array filled with all ones, with the given shape.
 
         Args:
-            shape: the shape of the array
+            shape (Number): The shape of the array
 
         Returns:
-            MXNet NDArray: ndarray of 1's with given shape
+            MXNet NDArray: ndarray of 1's with given shape.
         """
         return nd.ones(shape)
 
@@ -114,10 +114,10 @@ def zeros(self, shape):
         A new array filled with all zeros, with the given shape.
 
         Args:
-            shape: the shape of the array
+            shape (Number): The shape of the array
 
         Returns:
-            MXNet NDArray: ndarray of 0's with given shape
+            MXNet NDArray: ndarray of 0's with given shape.
         """
         return nd.zeros(shape)
 
@@ -127,11 +127,11 @@ def power(self, tensor_in_1, tensor_in_2):
         element-wise with broadcasting.
 
         Args:
-            tensor_in_1: tensor object
-            tensor_in_2: tensor object
+            tensor_in_1 (Tensor): Tensor object
+            tensor_in_2 (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: first array elements raised to powers from second array
+            MXNet NDArray: First array elements raised to powers from second array.
         """
         tensor_in_1 = self.astensor(tensor_in_1)
         tensor_in_2 = self.astensor(tensor_in_2)
@@ -142,10 +142,10 @@ def sqrt(self, tensor_in):
         Element-wise square-root value of the input.
 
         Args:
-            tensor_in: tensor object
+            tensor_in (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: element-wise square-root value
+            MXNet NDArray: Element-wise square-root value.
         """
         tensor_in = self.astensor(tensor_in)
         return nd.sqrt(tensor_in)
@@ -155,11 +155,11 @@ def divide(self, tensor_in_1, tensor_in_2):
         Element-wise division of the input arrays with broadcasting.
 
         Args:
-            tensor_in_1: tensor object
-            tensor_in_2: tensor object
+            tensor_in_1 (Tensor): Tensor object
+            tensor_in_2 (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: element-wise division of the input arrays
+            MXNet NDArray: Element-wise division of the input arrays.
         """
         tensor_in_1 = self.astensor(tensor_in_1)
         tensor_in_2 = self.astensor(tensor_in_2)
@@ -170,10 +170,10 @@ def log(self, tensor_in):
         Element-wise natural logarithmic value of the input.
 
         Args:
-            tensor_in: tensor object
+            tensor_in (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: element-wise natural logarithmic value
+            MXNet NDArray: Element-wise natural logarithmic value.
         """
         tensor_in = self.astensor(tensor_in)
         return nd.log(tensor_in)
@@ -183,10 +183,10 @@ def exp(self, tensor_in):
         Element-wise exponential value of the input.
 
         Args:
-            tensor_in: tensor object
+            tensor_in (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: element-wise exponential value
+            MXNet NDArray: Element-wise exponential value.
         """
         tensor_in = self.astensor(tensor_in)
         return nd.exp(tensor_in)
@@ -200,19 +200,20 @@ def stack(self, sequence, axis=0):
         if axis=-1 it will be the last dimension.
 
         Args:
-            sequence: sequence of arrays
-            axis: the axis along which to join the arrays
+            sequence (Array of Tensors): Sequence of arrays
+            axis (Number): The axis along which to join the arrays
 
         Returns:
-            MXNet NDArray: ndarray comprised of the elements of the sequence
+            MXNet NDArray: ndarray comprised of the elements of the sequence.
         """
         return nd.stack(*sequence, axis=axis)
 
     def where(self, mask, tensor_in_1, tensor_in_2):
         """
-        Apply a boolean selection mask to the elements of the input tensors
+        Apply a boolean selection mask to the elements of the input tensors.
+
+        Example::
 
-        Example:
             >>> where(
                 astensor([1, 0, 1]),
                 astensor([1, 1, 1]),
                 astensor([2, 2, 2]))
             [1. 2. 1.]
 
         Args:
-            mask: Boolean mask (boolean or tensor object of booleans)
-            tensor_in_1: tensor object
-            tensor_in_2: tensor object
+            mask (bool): Boolean mask (boolean or tensor object of booleans)
+            tensor_in_1 (Tensor): Tensor object
+            tensor_in_2 (Tensor): Tensor object
 
         Returns:
-            MXNet NDArray: The result of the mask being applied to the tensors
+            MXNet NDArray: The result of the mask being applied to the tensors.
         """
         mask = self.astensor(mask)
         tensor_in_1 = self.astensor(tensor_in_1)
@@ -235,21 +236,22 @@ def where(self, mask, tensor_in_1, tensor_in_2):
 
     def concatenate(self, sequence):
         """
-        Join the elements of the sequence
+        Join the elements of the sequence.
 
         Args:
-            sequence: the sequence of arrays to join
+            sequence (Array of Tensors): The sequence of arrays to join
 
         Returns:
-            MXNet NDArray: The ndarray of the joined elements
+            MXNet NDArray: The ndarray of the joined elements.
         """
         return nd.concat(*sequence, dim=0)
 
     def simple_broadcast(self, *args):
         """
-        Broadcast a sequence of 1 dimensional arrays
+        Broadcast a sequence of 1 dimensional arrays.
+
+        Example::
 
-        Example:
             >>> simple_broadcast(
                 astensor([1]),
                 astensor([2, 2]),
                 astensor([3, 3, 3]))
             [[1. 1. 1.]
              [2. 2. 2.]
              [3. 3. 3.]]
 
         Args:
-            args: sequence of arrays
+            args (Array of Tensors): Sequence of arrays
 
         Returns:
-            MXNet NDArray: The sequence broadcast together
+            MXNet NDArray: The sequence broadcast together.
         """
         max_dim = max(map(len, args))
         broadcast = []
         for arg in args:
             if len(arg) < max_dim:
                 broadcast.append(nd.broadcast_axis(
                     arg[0], axis=len(arg.shape) - 1, size=max_dim))
             else:
                 broadcast.append(arg)
         return nd.stack(*broadcast)
 
     def poisson(self, n, lam):
+        """
+        The continuous approximation to the probability density function of the Poisson
+        distribution given the parameters evaluated at `n`.
+
+        Args:
+            n (Number or Tensor): The value at which to evaluate the Poisson distribution p.d.f.
+                                  (the observed number of events)
+            lam (Number or Tensor): The mean of the Poisson distribution p.d.f.
+                                    (the expected number of events)
+
+        Returns:
+            MXNet NDArray: Value of N(n|lam, sqrt(lam)), the continuous approximation to Poisson(n|lam).
+        """
         return self.normal(n, lam, self.sqrt(lam))
 
     def normal(self, x, mu, sigma):
         """
-        Currently copying from PyTorch's source until can find a better way to do this
+        The probability density function of the Normal distribution given the parameters
+        evaluated at `x`.
+
+        Args:
+            x (Number or Tensor): The point at which to evaluate the Normal distribution p.d.f.
+            mu (Number or Tensor): The mean of the Normal distribution p.d.f.
+            sigma (Number or Tensor): The standard deviation of the Normal distribution p.d.f.
+
+        Returns:
+            MXNet NDArray: Value of N(x|mu, sigma).
         """
         x = self.astensor(x)
         mu = self.astensor(mu)
         sigma = self.astensor(sigma)
 
+        # This is currently copied directly from PyTorch's source until a better
+        # way can be found to do this in MXNet
+        # https://github.com/pytorch/pytorch/blob/master/torch/distributions/normal.py#L61-L66
         def log_prob(value, loc, scale):
             variance = scale ** 2
             log_scale = math.log(scale) if isinstance(
                 scale, Number) else scale.log()
             return -((value - loc) ** 2) / (2 * variance) - log_scale - math.log(math.sqrt(2 * math.pi))
         return self.exp(log_prob(x, mu, sigma))
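
For orientation, here is a minimal usage sketch of the backend as it stands at the end of this series. It is not itself part of any patch; it assumes the mxnet package is installed and that pyhf/tensor/mxnet_backend.py matches the final state of the diffs above, and it mirrors the checks in tests/test_tensor.py:

    from pyhf.tensor.mxnet_backend import mxnet_backend

    tb = mxnet_backend()

    # Round trip: Python list -> MXNet NDArray -> Python list
    assert tb.tolist(tb.astensor([1, 2, 3])) == [1, 2, 3]

    # outer() reshapes 1-d inputs to column/row form before multiplying
    assert tb.tolist(tb.outer([1, 2, 3], [4, 5, 6])) == [[4, 5, 6], [8, 10, 12], [12, 15, 18]]

    # where() is implemented arithmetically as mask * a + (1 - mask) * b
    assert tb.tolist(tb.where(tb.astensor([1, 0, 1]),
                              tb.astensor([1, 1, 1]),
                              tb.astensor([2, 2, 2]))) == [1, 2, 1]

    # normal() exponentiates the PyTorch-derived log_prob:
    # log N(x|mu, sigma) = -(x - mu)**2 / (2 * sigma**2) - log(sigma) - log(sqrt(2 * pi)),
    # so N(0|0, 1) should come out near 1 / sqrt(2 * pi) ~= 0.3989
    print(tb.tolist(tb.normal([0.0], [0.0], [1.0])))

    # poisson() is the continuous Gaussian approximation N(n | lam, sqrt(lam))
    print(tb.tolist(tb.poisson([120.0], [100.0])))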