Skip to content

Commit

Permalink
Add MXNet backend (#83)
Browse files Browse the repository at this point in the history
* Add placeholder MXNet backend

Add a placeholder file for the MXNet backend and partially fill it in.

In addition also add mxnet to packages to pip install and add
mxnet_backend to the list of backends to test.

* Add MXNet backend stack() and concatenate()

Other methods are added as well, but these are the two that were the
most important

* Add placeholder for outer()

The current outer() function is highly suboptimal. However, until a
better solution can be implimented in a more MXNet-centric style it will
at least pass the tests. The Pull Request in which this commit exists
should _NOT_ be merged in until outer() is updated with a better
solution.

* Use MXNet-centric implimentation of outer()

Use the linear algebra functions of MXNet to provide a more clean
implimenation of the outer product method, outer()

* Add mxnet_backend try block

* Add default Jupyter notebook example using MXNet backend

* Use list indexing instead of extended iterable unpacking

Extended iterable unpacking wasn't added until PEP 3132 and so isn't in
Python 2, which is causing the Python 2.7 tests to fail. While this is
annoying, it is probably worth still supporting Python 2.7 until the
official Python 2 end of life date.

c.f. https://www.python.org/dev/peps/pep-3132/

* Use Conda environment.yml for Binder

To support the more complex systems of multiple backends a Conda
environment.yml file is used instead of a pip requirements.txt file for
use with Binder.

c.f.
https://mybinder.readthedocs.io/en/latest/sample_repos.html#conda-environment-with-environment-yml

* Have simple_broadcast() return a MXNet NDArray

Additionally make a correction to outer().
nd.broadcast_mul() is the correct function to use in this method.

* Add docstrings

Add Google Style docstrings that should be compatible with Sphinx under
certain themes, such as the Sphinx theme created by Read The Docs. For
an example of what this would look like, see PyTorch's docs.

c.f.:
  http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google
  http://pytorch.org/docs/master/_modules/torch/distributions/bernoulli.html#Bernoulli
  • Loading branch information
matthewfeickert authored and lukasheinrich committed Feb 12, 2018
1 parent 3b27375 commit 9f8c4de
Show file tree
Hide file tree
Showing 7 changed files with 591 additions and 48 deletions.
111 changes: 111 additions & 0 deletions binder/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
name: pyhf
channels:
- pytorch
- https://conda.anaconda.org/NLeSC
- defaults
dependencies:
- ca-certificates=2017.08.26=h1d4fec5_0
- certifi=2018.1.18=py36_0
- cffi=1.11.4=py36h9745a5d_0
- cudatoolkit=8.0=3
- cudnn=7.0.5=cuda8.0_0
- freetype=2.8=hab7d2ae_1
- intel-openmp=2018.0.0=hc7b2577_8
- jpeg=9b=h024ee3a_2
- libedit=3.1=heed3624_0
- libffi=3.2.1=hd88cf55_4
- libgcc-ng=7.2.0=h7cc24e2_2
- libgfortran-ng=7.2.0=h9f7466a_2
- libpng=1.6.34=hb9fc6fc_0
- libstdcxx-ng=7.2.0=h7a57d05_2
- libtiff=4.0.9=h28f6b97_0
- mkl=2018.0.1=h19d6760_4
- nccl=1.3.4=cuda8.0_1
- ncurses=6.0=h9df7e31_2
- numpy=1.14.0=py36h3dfced4_1
- olefile=0.45.1=py36_0
- openssl=1.0.2n=hb7f436b_0
- pillow=5.0.0=py36h3deb7b8_0
- pip=9.0.1=py36h6c6f9ce_4
- pycparser=2.18=py36hf9f622e_1
- python=3.6.4=hc3d631a_1
- pytorch=0.3.0=py36cuda8.0cudnn7.0_0
- readline=7.0=ha6073c6_4
- setuptools=38.4.0=py36_0
- six=1.11.0=py36h372c433_1
- sqlite=3.22.0=h1bed415_0
- tk=8.6.7=hc745277_3
- torchvision=0.2.0=py36_0
- wheel=0.30.0=py36hfd4bba0_1
- xz=5.2.3=h55aa19d_2
- zlib=1.2.11=ha838bed_2
- cuda90=1.0=h6433d27_0
- pip:
- absl-py>=0.1.10
- ansiwrap>=0.8.3
- attrs>=17.4.0
- bleach>=1.5.0
- boto3>=1.5.26
- botocore>=1.8.40
- chardet>=3.0.4
- click>=6.7
- coverage>=4.5.1
- cycler>=0.10.0
- decorator>=4.2.1
- docutils>=0.14
- entrypoints>=0.2.3
- future>=0.16.0
- graphviz>=0.8.1
- html5lib>=0.9999999
- idna>=2.6
- ipython>=6.2.1
- ipython-genutils>=0.2.0
- jedi>=0.11.1
- jinja2>=2.10
- jmespath>=0.9.3
- jsonschema>=2.6.0
- jupyter-client>=5.2.2
- jupyter-core>=4.4.0
- markdown>=2.6.11
- markupsafe>=1.0
- matplotlib>=2.1.2
- mistune>=0.8.3
- mxnet>=1.0.0.post4
- nbconvert>=5.3.1
- nbformat>=4.4.0
- pandas>=0.22.0
- pandocfilters>=1.4.2
- papermill>=0.12.2
- parso>=0.1.1
- pexpect>=4.4.0
- pickleshare>=0.7.4
- pluggy>=0.6.0
- prompt-toolkit>=1.0.15
- protobuf>=3.5.1
- ptyprocess>=0.5.2
- py>=1.5.2
- pygments>=2.2.0
- pyhf>=0.0.4
- pyparsing>=2.2.0
- pytest>=3.4.0
- pytest-cov>=2.5.1
- python-dateutil>=2.6.1
- pytz>=2018.3
- pyyaml>=3.12
- pyzmq>=17.0.0
- requests>=2.18.4
- s3transfer>=0.1.12
- scipy>=1.0.0
- simplegeneric>=0.8.1
- tensorflow>=1.5.0
- tensorflow-tensorboard>=1.5.1
- testpath>=0.3.1
- textwrap3>=0.9.1
- torch>=0.3.0
- tornado>=4.5.3
- tqdm>=4.19.5
- traitlets>=4.3.2
- uproot>=2.6.10
- urllib3>=1.22
- wcwidth>=0.1.7
- werkzeug>=0.14.1
6 changes: 0 additions & 6 deletions binder/requirements.txt

This file was deleted.

96 changes: 96 additions & 0 deletions examples/notebooks/example-mxnet.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pyhf\n",
"from pyhf import hfpdf\n",
"from pyhf.simplemodels import hepdata_like\n",
"import mxnet as mx"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------\n",
"as MXNet\n",
"--------\n",
"<class 'mxnet.ndarray.ndarray.NDArray'> \n",
"[-22.87784958]\n",
"<NDArray 1 @cpu(0)>\n"
]
}
],
"source": [
"source = {\n",
" \"binning\": [2,-0.5,1.5],\n",
" \"bindata\": {\n",
" \"data\": [120.0, 180.0],\n",
" \"bkg\": [100.0, 150.0],\n",
" \"bkgerr\": [10.0, 10.0],\n",
" \"sig\": [30.0, 95.0]\n",
" }\n",
"}\n",
"\n",
"pdf = hepdata_like(source['bindata']['sig'], source['bindata']['bkg'], source['bindata']['bkgerr'])\n",
"data = source['bindata']['data'] + pdf.config.auxdata\n",
"\n",
"init_pars = pdf.config.suggested_init()\n",
"par_bounds = pdf.config.suggested_bounds()\n",
"\n",
"\n",
"print('--------\\nas MXNet\\n--------')\n",
"pyhf.tensorlib = pyhf.mxnet_backend()\n",
"v = pdf.logpdf(init_pars, data)\n",
"print(type(v),v)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
6 changes: 6 additions & 0 deletions pyhf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
except ImportError:
pass

try:
from .tensor.mxnet_backend import mxnet_backend
assert mxnet_backend
except ImportError:
pass


tensorlib = numpy_backend()
optimizer = scipy_optimizer()
Expand Down

0 comments on commit 9f8c4de

Please sign in to comment.