Skip to content

Commit

Permalink
Have simple_broadcast() return a MXNet NDArray
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfeickert committed Feb 12, 2018
1 parent 6fc5436 commit 7fc9d37
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions pyhf/tensor/mxnet_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from mxnet import nd
import logging
import math # Required for normal()
from numbers import Number # Required for normal()
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -211,11 +213,11 @@ def where(self, mask, tensor_in_1, tensor_in_2):
Apply a boolean selection mask to the elements of the input tensors
Example:
>>> mxnet_backend.where(
mxnet_backend.astensor([1, 0, 1]),
mxnet_backend.astensor([1, 1, 1]),
mxnet_backend.astensor([2, 2, 2]))
>>> [1. 2. 1.]
>>> where(
astensor([1, 0, 1]),
astensor([1, 1, 1]),
astensor([2, 2, 2]))
[1. 2. 1.]
Args:
mask: Boolean mask (boolean or tensor object of booleans)
Expand All @@ -228,7 +230,8 @@ def where(self, mask, tensor_in_1, tensor_in_2):
mask = self.astensor(mask)
tensor_in_1 = self.astensor(tensor_in_1)
tensor_in_2 = self.astensor(tensor_in_2)
return nd.multiply(mask, tensor_in_1) + nd.multiply(nd.subtract(1, mask), tensor_in_2)
return nd.add(nd.multiply(mask, tensor_in_1),
nd.multiply(nd.subtract(1, mask), tensor_in_2))

def concatenate(self, sequence):
"""
Expand All @@ -244,14 +247,32 @@ def concatenate(self, sequence):

def simple_broadcast(self, *args):
"""
There should be a more MXNet-style way to do this
Broadcast a sequence of 1 dimensional arrays
Example:
>>> simple_broadcast(
astensor([1]),
astensor([2, 2]),
astensor([3, 3, 3]))
[[1. 1. 1.]
[2. 2. 2.]
[3. 3. 3.]]
Args:
args: sequence of arrays
Returns:
MXNet NDArray: The sequence broadcast together
"""
broadcast = []
max_dim = max(map(len, args))
broadcast = []
for arg in args:
broadcast.append(self.astensor(arg)
if len(arg) > 1 else arg * self.ones(max_dim))
return broadcast
if len(arg) < max_dim:
broadcast.append(nd.broadcast_axis(
arg[0], axis=len(arg.shape) - 1, size=max_dim))
else:
broadcast.append(arg)
return nd.stack(*broadcast)

def poisson(self, n, lam):
return self.normal(n, lam, self.sqrt(lam))
Expand All @@ -260,8 +281,6 @@ def normal(self, x, mu, sigma):
"""
Currently copying from PyTorch's source until can find a better way to do this
"""
import math
from numbers import Number
x = self.astensor(x)
mu = self.astensor(mu)
sigma = self.astensor(sigma)
Expand Down

0 comments on commit 7fc9d37

Please sign in to comment.