Skip to content

Commit

Permalink
Use MXNet-centric implimentation of outer()
Browse files Browse the repository at this point in the history
Use the linear algebra functions of MXNet to provide a more clean
implimenation of the outer product method, outer()
  • Loading branch information
matthewfeickert committed Feb 12, 2018
1 parent f540dde commit bf92334
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pyhf/tensor/mxnet_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from mxnet import nd
import logging
import itertools # Hack fix for mxnet_backend.outer()
log = logging.getLogger(__name__)


Expand All @@ -25,7 +24,7 @@ def tolist(self, tensor_in):

def outer(self, tensor_in_1, tensor_in_2):
"""
The outer product of two tensors: u.v^T
The outer product of two tensors
Args:
tensor_in_1: tensor object
Expand All @@ -34,16 +33,21 @@ def outer(self, tensor_in_1, tensor_in_2):
Returns:
MXNet NDArray: The outer product
"""
# This is currently a rather stupid way to do things, so need to fix this
# Currently also is assuming only 1-d tensors, which is bad
tensor_in_1 = self.astensor(tensor_in_1)
tensor_in_2 = self.astensor(tensor_in_2)
tensor_in_2 = tensor_in_2.T
outer = nd.ones((tensor_in_1.size, tensor_in_2.size))
for i, j in itertools.product(range(tensor_in_1.size),
range(tensor_in_2.size)):
outer[i, j] = nd.dot(tensor_in_1[i], tensor_in_2[j])
return outer

tensor_1_shape = tensor_in_1.shape
tensor_2_shape = tensor_in_2.shape
if len(tensor_1_shape) == 1:
tensor_1_shape = (*tensor_1_shape, 1)
if len(tensor_2_shape) == 1:
tensor_2_shape = (*tensor_2_shape, 1)

rows1, cols1 = tensor_1_shape
rows2, cols2 = tensor_2_shape
return nd.reshape(nd.dot(tensor_in_1.reshape((rows1, 1, cols1, 1)),
tensor_in_2.reshape((1, rows2, 1, cols2))),
(rows1 * cols1, rows2 * cols2))

def astensor(self, tensor_in):
"""
Expand Down

0 comments on commit bf92334

Please sign in to comment.