Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing of kwargs to theano.function. #2358

Merged
merged 8 commits into from Aug 10, 2013
26 changes: 22 additions & 4 deletions sympy/printing/tests/test_theanocode.py
@@ -1,4 +1,5 @@
from sympy.external import import_module
from sympy.utilities.pytest import raises

theano = import_module('theano')
if theano:
Expand Down Expand Up @@ -135,10 +136,10 @@ def test_theano_function_simple():
f = theano_function([x, y], [x+y])
assert f(2, 3) == 5


def test_theano_function_numpy():
import numpy as np
f = theano_function([x, y], [x+y], dim=1)
f = theano_function([x, y], [x+y], dim=1,
dtypes={x: 'float64', y: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9

f = theano_function([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
Expand All @@ -147,6 +148,20 @@ def test_theano_function_numpy():
yy = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9

def test_theano_function_kwargs():
import numpy as np
f = theano_function([x, y, z], [x+y], dim=1, on_unused_input='ignore',
dtypes={x: 'float64', y: 'float64', z: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So normally would this raise an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because on_unsused_input is not an arg for dim_handling().

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And you can't pass in more values than specified inputs for theano.function.


f = theano_function([x, y, z], [x+y],
dtypes={x: 'float64', y: 'float64', z: 'float64'},
dim=1, on_unused_input='ignore')
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
zz = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9

def test_slice():
assert theano_code(slice(1, 2, 3)) == slice(1, 2, 3)
assert str(theano_code(slice(1, x, 3), dtypes={x: 'int32'})) ==\
Expand Down Expand Up @@ -189,8 +204,8 @@ def test_BlockMatrix_Inverse_execution():
inputs = A, B
output = B.I*A

cutsizes = {A: [(n/2, n/2), (k/2, k/2)],
B: [(n/2, n/2), (n/2, n/2)]}
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
B: [(n//2, n//2), (n//2, n//2)]}
cutinputs = [sympy.blockcut(i, *cutsizes[i]) for i in inputs]
cutoutput = output.subs(dict(zip(inputs, cutinputs)))

Expand Down Expand Up @@ -221,3 +236,6 @@ def test_AppliedUndef():
ft = theano_code(f(t))
assert isinstance(ft, tt.TensorVariable)
assert ft.name == 'f_t'

def test_bad_keyword_args_raise_error():
raises(Exception, lambda : theano_function([x], [x+1], foobar=3))
12 changes: 10 additions & 2 deletions sympy/printing/theanocode.py
@@ -1,4 +1,5 @@
from __future__ import print_function, division
import inspect
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is inspect part of the standard library? Is it supported by 2.6 and 3.3?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know. Tests passed on Travis with it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we add a travis test with Theano?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly because I don't know much about travis. If you're knowledgeable enough about this then I strongly encourage it, especially if people are going t start using this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll set travis up for some tests with theano when I have a few minutes.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great idea. I have see this also come up a few time on travis-ci mailing
list, but nobody seam to have worked on that.

On Fri, Aug 9, 2013 at 4:14 PM, Aaron Meurer notifications@github.comwrote:

In sympy/printing/theanocode.py:

@@ -1,4 +1,5 @@
from future import print_function, division
+import inspect

It sounds like such a PPA would be of general interest to the community at
large. Maybe we should bring it up on the NumFocus mailing list.


Reply to this email directly or view it on GitHubhttps://github.com//pull/2358/files#r5692815
.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a PPA with the latest release SciPy stack compiled for many python versions would be helpful to a ton of other projects.

But this brings up the issue of version of the SciPy stack. What if I want to test against the latest dev version of packages in the SciPy stack. Or what if my software needs to be tested against older versions of the SciPy stack.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe a PPA can contain multiple versions of the same package. See for example the deadsnakes PPA.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They aren't really the same versions of the same package, they have different names. You have to apt-get install python2.6 or python3.2. They are different packages in the deadsnakes for each version of python. So the equivalent for scipy would be apt-get scipy0.7.2 or apt-get numpy1.7.1. So I guess as long as you build them all and have different names it would work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'd have to do that anyway for the different python versions.


from sympy.utilities import default_sort_key
from sympy.external import import_module
Expand Down Expand Up @@ -173,7 +174,8 @@ def theano_code(expr, **kwargs):
return TheanoPrinter({}).doprint(expr, **kwargs)


def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):
def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=(),
**kwargs):
""" Handle various input types for dimensions in tensor_wrap

See Also:
Expand All @@ -192,8 +194,14 @@ def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):
def theano_function(inputs, outputs, dtypes={}, **kwargs):
""" Create Theano function from SymPy expressions """
broadcastables = dim_handling(inputs, **kwargs)

# Remove keyword arguments corresponding to dim_handling
dim_names = inspect.getargspec(dim_handling)[0]
theano_kwargs = dict((k, v) for k, v in kwargs.items()
if k not in dim_names)

code = partial(theano_code, dtypes=dtypes, broadcastables=broadcastables)
tinputs = map(code, inputs)
toutputs = map(code, outputs)
toutputs = toutputs[0] if len(toutputs) == 1 else toutputs
return theano.function(tinputs, toutputs)
return theano.function(tinputs, toutputs, **theano_kwargs)