Skip to content

Commit

Permalink
Merge pull request #2358 from moorepants/unused-arg
Browse files Browse the repository at this point in the history
Allow passing of kwargs to theano.function.
  • Loading branch information
mrocklin committed Aug 10, 2013
2 parents 196fdb3 + 6ab3bc0 commit c422e97
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
26 changes: 22 additions & 4 deletions sympy/printing/tests/test_theanocode.py
@@ -1,4 +1,5 @@
from sympy.external import import_module
from sympy.utilities.pytest import raises

theano = import_module('theano')
if theano:
Expand Down Expand Up @@ -135,10 +136,10 @@ def test_theano_function_simple():
f = theano_function([x, y], [x+y])
assert f(2, 3) == 5


def test_theano_function_numpy():
import numpy as np
f = theano_function([x, y], [x+y], dim=1)
f = theano_function([x, y], [x+y], dim=1,
dtypes={x: 'float64', y: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9

f = theano_function([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
Expand All @@ -147,6 +148,20 @@ def test_theano_function_numpy():
yy = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9

def test_theano_function_kwargs():
import numpy as np
f = theano_function([x, y, z], [x+y], dim=1, on_unused_input='ignore',
dtypes={x: 'float64', y: 'float64', z: 'float64'})
assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9

f = theano_function([x, y, z], [x+y],
dtypes={x: 'float64', y: 'float64', z: 'float64'},
dim=1, on_unused_input='ignore')
xx = np.arange(3).astype('float64')
yy = 2*np.arange(3).astype('float64')
zz = 2*np.arange(3).astype('float64')
assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9

def test_slice():
assert theano_code(slice(1, 2, 3)) == slice(1, 2, 3)
assert str(theano_code(slice(1, x, 3), dtypes={x: 'int32'})) ==\
Expand Down Expand Up @@ -189,8 +204,8 @@ def test_BlockMatrix_Inverse_execution():
inputs = A, B
output = B.I*A

cutsizes = {A: [(n/2, n/2), (k/2, k/2)],
B: [(n/2, n/2), (n/2, n/2)]}
cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
B: [(n//2, n//2), (n//2, n//2)]}
cutinputs = [sympy.blockcut(i, *cutsizes[i]) for i in inputs]
cutoutput = output.subs(dict(zip(inputs, cutinputs)))

Expand Down Expand Up @@ -221,3 +236,6 @@ def test_AppliedUndef():
ft = theano_code(f(t))
assert isinstance(ft, tt.TensorVariable)
assert ft.name == 'f_t'

def test_bad_keyword_args_raise_error():
raises(Exception, lambda : theano_function([x], [x+1], foobar=3))
12 changes: 10 additions & 2 deletions sympy/printing/theanocode.py
@@ -1,4 +1,5 @@
from __future__ import print_function, division
import inspect

from sympy.utilities import default_sort_key
from sympy.external import import_module
Expand Down Expand Up @@ -173,7 +174,8 @@ def theano_code(expr, **kwargs):
return TheanoPrinter({}).doprint(expr, **kwargs)


def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):
def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=(),
**kwargs):
""" Handle various input types for dimensions in tensor_wrap
See Also:
Expand All @@ -192,8 +194,14 @@ def dim_handling(inputs, dim=None, dims={}, broadcastables={}, keys=()):
def theano_function(inputs, outputs, dtypes={}, **kwargs):
""" Create Theano function from SymPy expressions """
broadcastables = dim_handling(inputs, **kwargs)

# Remove keyword arguments corresponding to dim_handling
dim_names = inspect.getargspec(dim_handling)[0]
theano_kwargs = dict((k, v) for k, v in kwargs.items()
if k not in dim_names)

code = partial(theano_code, dtypes=dtypes, broadcastables=broadcastables)
tinputs = map(code, inputs)
toutputs = map(code, outputs)
toutputs = toutputs[0] if len(toutputs) == 1 else toutputs
return theano.function(tinputs, toutputs)
return theano.function(tinputs, toutputs, **theano_kwargs)

0 comments on commit c422e97

Please sign in to comment.