-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MXNet backend #83
Changes from 1 commit
faac9eb
0e8b5e4
f540dde
bf92334
9fd4d01
e6326dd
03b3f77
6fc5436
c0c7654
e40a007
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import mxnet as mx | ||
from mxnet import nd | ||
import logging | ||
log = logging.getLogger(__name__) | ||
|
||
|
||
class mxnet_backend(object):
    """Backend for tensor operations using MXNet NDArrays."""

    def __init__(self, **kwargs):
        # NOTE(review): MXNet has no session concept (unlike TensorFlow);
        # this attribute is kept only so existing callers that pass or read
        # `session` keep working — candidate for removal.
        self.session = kwargs.get('session')

    def tolist(self, tensor_in):
        """
        Convert a tensor to a (possibly nested) Python list.

        Args:
            tensor_in: MXNet tensor or anything `astensor` accepts

        Returns:
            list: the possibly nested list of tensor elements.
        """
        tensor_in = self.astensor(tensor_in)
        return tensor_in.asnumpy().tolist()

    def outer(self, tensor_in_1, tensor_in_2):
        """
        The outer product of two tensors: u.v^T

        Args:
            tensor_in_1: tensor object
            tensor_in_2: tensor object

        Returns:
            MXNet NDArray: The outer product
        """
        tensor_in_1 = self.astensor(tensor_in_1)
        tensor_in_2 = self.astensor(tensor_in_2)
        # Flatten u into a column vector and v into a row vector, then a
        # matrix product gives the outer product u.v^T.
        return nd.dot(tensor_in_1.reshape((-1, 1)),
                      tensor_in_2.reshape((1, -1)))

    def astensor(self, tensor_in):
        """
        Convert a tensor-like object to an MXNet NDArray.

        Args:
            tensor_in: tensor object (list, numpy array, NDArray, ...)

        Returns:
            MXNet NDArray: a multi-dimensional, fixed-size homogenous array.
        """
        return nd.array(tensor_in)

    def sum(self, tensor_in, axis=None):
        """
        Compute the sum of array elements over given axes.

        Args:
            tensor_in: tensor object
            axis: the axes over which to sum (None sums over all axes)

        Returns:
            MXNet NDArray: ndarray of the sum over the axes
        """
        tensor_in = self.astensor(tensor_in)
        # Bug fix: the original compared against nd.Size([]), which does not
        # exist in MXNet (torch.Size copypasta) and raised AttributeError.
        if axis is None or tensor_in.shape == ():
            return nd.sum(tensor_in)
        return nd.sum(tensor_in, axis)

    def product(self, tensor_in, axis=None):
        """
        Compute the product of array elements over given axes.

        Args:
            tensor_in: tensor object
            axis: the axes over which to take the product (None uses all axes)

        Returns:
            MXNet NDArray: ndarray of the product over the axes
        """
        tensor_in = self.astensor(tensor_in)
        if axis is None:
            return nd.prod(tensor_in)
        return nd.prod(tensor_in, axis)

    def ones(self, shape):
        """
        A new array filled with all ones, with the given shape.

        Args:
            shape: the shape of the array

        Returns:
            MXNet NDArray: ndarray of 1's with given shape
        """
        return nd.ones(shape)

    def zeros(self, shape):
        """
        A new array filled with all zeros, with the given shape.

        Args:
            shape: the shape of the array

        Returns:
            MXNet NDArray: ndarray of 0's with given shape
        """
        return nd.zeros(shape)

    def power(self, tensor_in_1, tensor_in_2):
        """
        Result of first array elements raised to powers from second array,
        element-wise with broadcasting.

        Args:
            tensor_in_1: tensor object
            tensor_in_2: tensor object

        Returns:
            MXNet NDArray: first array elements raised to powers from second array
        """
        tensor_in_1 = self.astensor(tensor_in_1)
        tensor_in_2 = self.astensor(tensor_in_2)
        return nd.power(tensor_in_1, tensor_in_2)

    def sqrt(self, tensor_in):
        """
        Element-wise square-root value of the input.

        Args:
            tensor_in: tensor object

        Returns:
            MXNet NDArray: element-wise square-root value
        """
        tensor_in = self.astensor(tensor_in)
        return nd.sqrt(tensor_in)

    def divide(self, tensor_in_1, tensor_in_2):
        """
        Element-wise division of the input arrays with broadcasting.

        Args:
            tensor_in_1: tensor object
            tensor_in_2: tensor object

        Returns:
            MXNet NDArray: element-wise division of the input arrays
        """
        tensor_in_1 = self.astensor(tensor_in_1)
        tensor_in_2 = self.astensor(tensor_in_2)
        return nd.divide(tensor_in_1, tensor_in_2)

    def log(self, tensor_in):
        """
        Element-wise natural logarithm of the input.

        Args:
            tensor_in: tensor object

        Returns:
            MXNet NDArray: element-wise natural logarithmic value
        """
        tensor_in = self.astensor(tensor_in)
        return nd.log(tensor_in)

    def exp(self, tensor_in):
        """
        Element-wise exponential value of the input.

        Args:
            tensor_in: tensor object

        Returns:
            MXNet NDArray: element-wise exponential value
        """
        tensor_in = self.astensor(tensor_in)
        return nd.exp(tensor_in)

    def stack(self, sequence, axis=0):
        """
        Join a sequence of arrays along a new axis.

        Args:
            sequence: sequence of tensor objects
            axis: the axis in the result along which the inputs are stacked

        Returns:
            MXNet NDArray: the stacked array, with one more dimension than
            the inputs
        """
        return nd.stack(*[self.astensor(tensor) for tensor in sequence],
                        axis=axis)

    def where(self, mask, tensor_in_1, tensor_in_2):
        """
        Apply a boolean selection mask to the elements of the input tensors.

        Args:
            mask: boolean mask (1 selects from tensor_in_1, 0 from tensor_in_2)
            tensor_in_1: tensor object
            tensor_in_2: tensor object

        Returns:
            MXNet NDArray: element-wise selection of the two inputs
        """
        mask = self.astensor(mask)
        tensor_in_1 = self.astensor(tensor_in_1)
        tensor_in_2 = self.astensor(tensor_in_2)
        return nd.where(mask, tensor_in_1, tensor_in_2)

    def concatenate(self, sequence):
        """
        Join a sequence of arrays along the first axis.

        Args:
            sequence: sequence of tensor objects

        Returns:
            MXNet NDArray: the concatenated array
        """
        return nd.concat(*[self.astensor(tensor) for tensor in sequence],
                         dim=0)

    def simple_broadcast(self, *args):
        """
        Broadcast a sequence of 1-dimensional tensors to a common length.

        Size-1 tensors are repeated to the maximum length found among the
        arguments; tensors already at that length pass through unchanged.

        Args:
            *args: sequence of tensor objects

        Returns:
            list of MXNet NDArrays: the arguments, all of the same length

        Raises:
            ValueError: if an argument is neither size 1 nor the maximum size
        """
        tensors = [self.astensor(arg) for arg in args]
        max_dim = max(tensor.size for tensor in tensors)
        broadcast = []
        for tensor in tensors:
            if tensor.size == max_dim:
                broadcast.append(tensor)
            elif tensor.size == 1:
                broadcast.append(
                    nd.broadcast_to(tensor.reshape((1,)), (max_dim,)))
            else:
                raise ValueError(
                    'cannot broadcast a tensor of size {} to size {}'.format(
                        tensor.size, max_dim))
        return broadcast

    def poisson(self, n, lam):
        """
        The continuous approximation to the Poisson pmf evaluated at n,
        given the rate parameter lam.

        Per review discussion: since continuous values of n are required,
        use a Gaussian approximation with mu = lam and sigma = sqrt(lam).

        Args:
            n: value at which to evaluate the approximate pmf
            lam: rate parameter of the Poisson distribution

        Returns:
            MXNet NDArray: the approximate probability density at n
        """
        lam = self.astensor(lam)
        return self.normal(n, lam, self.sqrt(lam))

    def normal(self, x, mu, sigma):
        """
        The probability density function of the Normal distribution
        evaluated at x, given mean mu and standard deviation sigma.

        Args:
            x: value at which to evaluate the pdf
            mu: mean of the distribution
            sigma: standard deviation of the distribution

        Returns:
            MXNet NDArray: the probability density at x
        """
        import math  # local import: only the 1/sqrt(2*pi) constant is needed
        x = self.astensor(x)
        mu = self.astensor(mu)
        sigma = self.astensor(sigma)
        # pdf(x) = exp(-(x - mu)^2 / (2 sigma^2)) / (sigma sqrt(2 pi))
        return nd.exp(-nd.power(x - mu, 2) / (2.0 * nd.power(sigma, 2))) \
            / (sigma * math.sqrt(2.0 * math.pi))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like MXNet does not have a concept of a session (at least I can't find a reference to one), so I think this can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that was the result of some sloppy copypasta and me not checking things late at night. I'll fix that up when I get back to this after finishing work tonight.