Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MXNet backend #83

Merged
merged 10 commits into from
Feb 12, 2018
263 changes: 263 additions & 0 deletions pyhf/tensor/mxnet_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import mxnet as mx
from mxnet import nd
import logging
log = logging.getLogger(__name__)


class mxnet_backend(object):
"""Backend for MXNet"""

def __init__(self, **kwargs):
    """No configuration is needed for the MXNet backend; **kwargs is accepted
    (and ignored) to match the constructor signature of the other backends."""
    pass

def tolist(self, tensor_in):
    """
    Convert a tensor (or array-like) into a plain Python list.

    Args:
        tensor_in: MXNet tensor or anything `astensor` accepts

    Returns:
        list: the (possibly nested) list of the tensor's elements
    """
    # Round-trip through NumPy, which provides the nested-list conversion.
    return self.astensor(tensor_in).asnumpy().tolist()

def outer(self, tensor_in_1, tensor_in_2):
    """
    The outer product of two tensors: u.v^T

    Args:
        tensor_in_1: tensor object
        tensor_in_2: tensor object

    Returns:
        MXNet NDArray: The outer product
    """
    tensor_in_1 = self.astensor(tensor_in_1)
    tensor_in_2 = self.astensor(tensor_in_2)
    # Original body was `pass` (returned None). Form u.v^T by matrix-
    # multiplying a column vector (n, 1) with a row vector (1, m).
    return nd.dot(tensor_in_1.reshape((-1, 1)),
                  tensor_in_2.reshape((1, -1)))

def astensor(self, tensor_in):
    """
    Coerce the input into an MXNet NDArray.

    Args:
        tensor_in: tensor object (or any array-like `nd.array` accepts)

    Returns:
        MXNet NDArray: a multi-dimensional, fixed-size homogenous array
    """
    tensor_out = nd.array(tensor_in)
    return tensor_out

def sum(self, tensor_in, axis=None):
    """
    Compute the sum of array elements over given axes.

    Args:
        tensor_in: tensor object
        axis: the axes over which to sum (None sums over all axes)

    Returns:
        MXNet NDArray: ndarray of the sum over the axes
    """
    tensor_in = self.astensor(tensor_in)
    # The original guard also tested `tensor_in.shape == nd.array([]).size`,
    # which compares a shape tuple against an int and is therefore always
    # False — only the axis check is meaningful.
    if axis is None:
        return nd.sum(tensor_in)
    return nd.sum(tensor_in, axis=axis)

def product(self, tensor_in, axis=None):
    """
    Product of array elements over given axes.

    Args:
        tensor_in: tensor object
        axis: the axes over which to take the product (None reduces all axes)

    Returns:
        MXNet NDArray: ndarray of the product over the axes
    """
    tensor_in = self.astensor(tensor_in)
    # Reduce over everything when no axis is requested.
    return nd.prod(tensor_in) if axis is None else nd.prod(tensor_in, axis)

def ones(self, shape):
    """
    Create a new array of the given shape, filled with ones.

    Args:
        shape: the shape of the array

    Returns:
        MXNet NDArray: ndarray of 1's with given shape
    """
    return nd.ones(shape)

def zeros(self, shape):
    """
    Create a new array of the given shape, filled with zeros.

    Args:
        shape: the shape of the array

    Returns:
        MXNet NDArray: ndarray of 0's with given shape
    """
    return nd.zeros(shape)

def power(self, tensor_in_1, tensor_in_2):
    """
    Element-wise exponentiation with broadcasting: the elements of the
    first array raised to the powers given by the second array.

    Args:
        tensor_in_1: tensor object (bases)
        tensor_in_2: tensor object (exponents)

    Returns:
        MXNet NDArray: first array elements raised to powers from second array
    """
    base = self.astensor(tensor_in_1)
    exponent = self.astensor(tensor_in_2)
    return nd.power(base, exponent)

def sqrt(self, tensor_in):
    """
    Element-wise square root of the input.

    Args:
        tensor_in: tensor object

    Returns:
        MXNet NDArray: element-wise square-root value
    """
    return nd.sqrt(self.astensor(tensor_in))

def divide(self, tensor_in_1, tensor_in_2):
    """
    Element-wise division of the input arrays with broadcasting.

    Args:
        tensor_in_1: tensor object (numerator)
        tensor_in_2: tensor object (denominator)

    Returns:
        MXNet NDArray: element-wise division of the input arrays
    """
    numerator = self.astensor(tensor_in_1)
    denominator = self.astensor(tensor_in_2)
    return nd.divide(numerator, denominator)

def log(self, tensor_in):
    """
    Element-wise natural logarithm of the input.

    Args:
        tensor_in: tensor object

    Returns:
        MXNet NDArray: element-wise natural logarithmic value
    """
    return nd.log(self.astensor(tensor_in))

def exp(self, tensor_in):
    """
    Element-wise exponential of the input.

    Args:
        tensor_in: tensor object

    Returns:
        MXNet NDArray: element-wise exponential value
    """
    return nd.exp(self.astensor(tensor_in))

def stack(self, sequence, axis=0):
    """
    Join a sequence of arrays along a new axis.

    The `axis` parameter is the index of the new axis in the result's
    dimensions: axis=0 puts it first, axis=-1 puts it last.

    Args:
        sequence: sequence of arrays
        axis: the axis along which to join the arrays

    Returns:
        MXNet NDArray: ndarray comprised of the elements of the sequence
    """
    tensors = list(sequence)
    return nd.stack(*tensors, axis=axis)

def where(self, mask, tensor_in_1, tensor_in_2):
    """
    Select elements from two tensors according to a boolean mask.

    Example:
        >>> mxnet_backend.where(
            mxnet_backend.astensor([1, 0, 1]),
            mxnet_backend.astensor([1, 1, 1]),
            mxnet_backend.astensor([2, 2, 2]))
        >>> [1. 2. 1.]

    Args:
        mask: Boolean mask (boolean or tensor object of booleans)
        tensor_in_1: tensor object (selected where mask is 1)
        tensor_in_2: tensor object (selected where mask is 0)

    Returns:
        MXNet NDArray: The result of the mask being applied to the tensors
    """
    mask = self.astensor(mask)
    on_true = self.astensor(tensor_in_1)
    on_false = self.astensor(tensor_in_2)
    # Arithmetic select: mask * a + (1 - mask) * b
    return mask * on_true + (1 - mask) * on_false

def concatenate(self, sequence):
    """
    Join the elements of the sequence along the first dimension.

    Args:
        sequence: the sequence of arrays to join

    Returns:
        MXNet NDArray: The ndarray of the joined elements
    """
    tensors = list(sequence)
    return nd.concat(*tensors, dim=0)

def simple_broadcast(self, *args):
    """
    Broadcast a collection of 1-d array-likes to a common length.

    Arguments of length 1 are stretched (by multiplying with a ones
    vector) to the maximum length found among the arguments; longer
    arguments are converted to tensors and passed through unchanged.

    Args:
        args: the array-likes to broadcast

    Returns:
        list of MXNet NDArray: the broadcast tensors
    """
    max_dim = max(map(len, args))
    broadcast = []
    for arg in args:
        # Convert every argument; the original code skipped astensor on
        # the length-1 branch, returning a mixed list of types and
        # multiplying a raw Python list by an NDArray.
        tensor = self.astensor(arg)
        broadcast.append(tensor if len(tensor) > 1
                         else tensor * self.ones(max_dim))
    return broadcast

def poisson(self, n, lam):
    # Continuous approximation to the Poisson pmf: a normal density with
    # mu = lam and sigma = sqrt(lam). Per the PR discussion, a continuous
    # (differentiable in lam) stand-in for the Poisson is required here.
    return self.normal(n, lam, self.sqrt(lam))

def normal(self, x, mu, sigma):
    """
    The probability density function of the Normal distribution evaluated
    at `x` given parameters of mean `mu` and standard deviation `sigma`.

    Note: `mx.nd.random.normal` draws samples (like `np.random.normal`);
    it does not evaluate the density, so the log-density is computed
    explicitly here and exponentiated.

    Args:
        x: tensor or array-like of the value(s) at which to evaluate the pdf
        mu: tensor or array-like of the mean(s)
        sigma: tensor or array-like of the standard deviation(s)

    Returns:
        MXNet NDArray: the density value(s) at x
    """
    import math
    x = self.astensor(x)
    mu = self.astensor(mu)
    sigma = self.astensor(sigma)

    # log N(x | loc, scale) =
    #     -(x - loc)^2 / (2 scale^2) - log(scale) - log(sqrt(2 pi))
    # `scale` is always an NDArray here (astensor above), so the original
    # isinstance(scale, Number) branch was dead code and is removed.
    def log_prob(value, loc, scale):
        variance = scale ** 2
        log_scale = scale.log()
        return (-((value - loc) ** 2) / (2 * variance)
                - log_scale
                - math.log(math.sqrt(2 * math.pi)))

    return self.exp(log_prob(x, mu, sigma))
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
'torch': [
'torch'
],
'mxnet':[
'mxnet',
],
'develop': [
'pyflakes',
'pytest>=3.2.0',
Expand All @@ -29,7 +32,9 @@
'uproot',
'papermill',
'torch',
'tensorflow'
'tensorflow',
'mxnet>=1.0.0',
'graphviz'
]
},
entry_points = {
Expand Down