Adapt Theano backend to new design
Signed-off-by: Niklas Koep <niklas.koep@gmail.com>
nkoep committed Feb 1, 2020
1 parent eac38db commit 775bf75
Showing 2 changed files with 23 additions and 38 deletions.
8 changes: 8 additions & 0 deletions examples/dominant_invariant_subspace.py
@@ -61,10 +61,18 @@ def cost(X):
         return -tf.tensordot(X, tf.matmul(A, X), axes=2)
     elif backend == "Theano":
         X = T.matrix()
+        U = T.matrix()
 
         @pymanopt.function.Theano(X)
         def cost(X):
             return -T.dot(X.T, T.dot(A, X)).trace()
+
+        # Define the Euclidean Hessian-vector product explicitly for the
+        # purpose of demonstration. The Euclidean gradient is automatically
+        # calculated via Theano's autodiff capabilities.
+        @pymanopt.function.Theano(X, U)
+        def ehess(X, U):
+            return -T.dot(A + A.T, U)
     else:
         raise ValueError("Unsupported backend '{:s}'".format(backend))
 
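For context, the decorated cost and the hand-written ehess plug into a pymanopt problem in the usual way. The sketch below assumes the rest of the example's setup (a Grassmann manifold, a trust-region solver, and illustrative sizes); it is not part of this commit:

    import numpy as np

    from pymanopt import Problem
    from pymanopt.manifolds import Grassmann
    from pymanopt.solvers import TrustRegions

    # Illustrative sizes; the actual example derives them from its inputs.
    num_rows, subspace_dimension = 128, 3
    A = np.random.normal(size=(num_rows, num_rows))
    A = 0.5 * (A + A.T)  # the cost assumes a symmetric matrix

    manifold = Grassmann(num_rows, subspace_dimension)
    # `cost` and `ehess` are the decorated functions from the hunk above.
    problem = Problem(manifold=manifold, cost=cost, ehess=ehess)
    Xopt = TrustRegions().solve(problem)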
53 changes: 15 additions & 38 deletions pymanopt/autodiff/backends/_theano.py
@@ -1,4 +1,3 @@
-import functools
 import itertools
 
 try:
@@ -11,7 +10,6 @@
 
 from ._backend import Backend
 from .. import make_graph_backend_decorator
-from ...tools import flatten_arguments, group_return_values, unpack_arguments
 
 
 class _TheanoBackend(Backend):
@@ -26,41 +24,29 @@ def is_available():
     def is_compatible(self, function, arguments):
         if not isinstance(function, T.TensorVariable):
             return False
-        flattened_arguments = flatten_arguments(arguments)
         return all([isinstance(argument, T.TensorVariable)
-                    for argument in flattened_arguments])
+                    for argument in arguments])
 
     def _compile_function_without_warnings(self, *args, **kwargs):
         return theano.function(*args, **kwargs, on_unused_input="ignore")
 
     @Backend._assert_backend_available
     def compile_function(self, function, arguments):
-        """Compiles a Theano graph into a python function."""
-        flattened_arguments = flatten_arguments(arguments)
-        compiled_function = self._compile_function_without_warnings(
-            flattened_arguments, function)
-        if len(flattened_arguments) == 1:
-            return compiled_function
-        return unpack_arguments(compiled_function)
+        """Compiles a Theano graph into a callable."""
+        return self._compile_function_without_warnings(arguments, function)
 
     @Backend._assert_backend_available
     def compute_gradient(self, function, arguments):
-        """Returns a compiled function computing the gradient of `function`
-        with respect to 'arguments'.
+        """Returns a compiled function computing the gradient of ``function``
+        with respect to ``arguments``.
         """
-        flattened_arguments = flatten_arguments(arguments)
-
-        if len(flattened_arguments) == 1:
-            (argument,) = flattened_arguments
+        if len(arguments) == 1:
+            (argument,) = arguments
             gradient = T.grad(function, argument)
-            return self._compile_function_without_warnings(
-                flattened_arguments, gradient)
+            return self._compile_function_without_warnings(arguments, gradient)
 
-        gradient = T.grad(function, flattened_arguments)
-        compiled_gradient = self._compile_function_without_warnings(
-            flattened_arguments, gradient)
-        return group_return_values(
-            unpack_arguments(compiled_gradient), arguments)
+        gradient = T.grad(function, arguments)
+        return self._compile_function_without_warnings(arguments, gradient)
 
     def _compute_unary_hessian_vector_product(self, gradient, argument):
         """Returns a function accepting two arguments to compute a
@@ -104,28 +90,19 @@ def _compute_nary_hessian_vector_product(gradients, arguments):
             list(itertools.chain(arguments, argument_types)), Rop)
 
     @Backend._assert_backend_available
-    def compute_hessian(self, function, arguments):
+    def compute_hessian_vector_product(self, function, arguments):
         """Computes the directional derivative of the gradient, which is
         equivalent to computing a Hessian-vector product with the direction
         vector.
         """
-        flattened_arguments = flatten_arguments(arguments)
-
-        if len(flattened_arguments) == 1:
-            (argument,) = flattened_arguments
+        if len(arguments) == 1:
+            (argument,) = arguments
             gradient = T.grad(function, argument)
             return self._compute_unary_hessian_vector_product(
                 gradient, argument)
 
-        gradients = T.grad(function, flattened_arguments)
-        hessian_vector_product = self._compute_nary_hessian_vector_product(
-            gradients, flattened_arguments)
-
-        @functools.wraps(hessian_vector_product)
-        def wrapper(points, vectors):
-            return hessian_vector_product(*flatten_arguments(points),
-                                          *flatten_arguments(vectors))
-        return group_return_values(wrapper, arguments)
+        gradients = T.grad(function, arguments)
+        return self._compute_nary_hessian_vector_product(gradients, arguments)
 
 
 Theano = make_graph_backend_decorator(_TheanoBackend)
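The net effect of the backend changes is that compile_function and compute_gradient now receive a flat list of arguments and can hand it to theano.function and T.grad directly. A minimal sketch of that pattern outside the backend (the toy cost and shapes are assumptions, not part of this commit):

    import numpy as np
    import theano
    import theano.tensor as T

    X = T.matrix()
    cost = -T.dot(X.T, X).trace()  # toy scalar cost
    gradient = T.grad(cost, X)     # symbolic Euclidean gradient

    # Mirrors _compile_function_without_warnings: compile the graph into a
    # callable while silencing warnings about unused inputs.
    compiled_gradient = theano.function([X], gradient,
                                        on_unused_input="ignore")

    point = np.random.normal(size=(4, 2))
    # For cost(X) = -trace(X.T X) the gradient is -2 X.
    np.testing.assert_allclose(compiled_gradient(point), -2 * point)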
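Similarly, the renamed compute_hessian_vector_product relies on the identity stated in its docstring: the Hessian-vector product is the directional derivative of the gradient, which Theano exposes through the R-operator. A sketch under the same toy-cost assumption:

    import numpy as np
    import theano
    import theano.tensor as T

    X = T.matrix()
    U = T.matrix()
    cost = -T.dot(X.T, X).trace()
    gradient = T.grad(cost, X)

    # Directional derivative of the gradient at X in direction U, i.e. the
    # Hessian-vector product, computed with Theano's R-operator.
    hvp = theano.function([X, U], T.Rop(gradient, X, U),
                          on_unused_input="ignore")

    point = np.random.normal(size=(4, 2))
    vector = np.random.normal(size=(4, 2))
    # The Hessian of -trace(X.T X) is -2 times the identity, so the
    # Hessian-vector product equals -2 U.
    np.testing.assert_allclose(hvp(point, vector), -2 * vector)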
