Commit
Refactor TensorFlow backend for new backend logic
Signed-off-by: Niklas Koep <niklas.koep@gmail.com>
nkoep committed Feb 1, 2020
1 parent 775bf75 commit 4a28bbf
Showing 4 changed files with 62 additions and 74 deletions.
7 changes: 7 additions & 0 deletions examples/dominant_invariant_subspace.py
@@ -59,6 +59,13 @@ def cost(X):
         @pymanopt.function.TensorFlow(X)
         def cost(X):
             return -tf.tensordot(X, tf.matmul(A, X), axes=2)
+
+        # Define the Euclidean gradient explicitly for the purpose of
+        # demonstration. The Euclidean Hessian-vector product is automatically
+        # calculated via TensorFlow's autodiff capabilities.
+        @pymanopt.function.TensorFlow(X)
+        def egrad(X):
+            return -tf.matmul(A + A.T, X)
     elif backend == "Theano":
         X = T.matrix()
         U = T.matrix()
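For context, here is a minimal, hypothetical sketch (not part of this commit) of how the decorated cost and egrad above would be handed to a pymanopt problem. The Grassmann manifold, the TrustRegions solver, and the Problem(cost=..., egrad=...) signature are assumptions based on the 0.2-era API used by this example:

    import numpy as np
    import tensorflow as tf

    import pymanopt
    from pymanopt import Problem
    from pymanopt.manifolds import Grassmann
    from pymanopt.solvers import TrustRegions

    # Assumed problem data: a random symmetric matrix whose dominant
    # p-dimensional invariant subspace we want to recover.
    n, p = 128, 3
    A = np.random.randn(n, n)
    A = 0.5 * (A + A.T)

    # Declare the TensorFlow variable the decorated functions close over.
    X = tf.Variable(tf.zeros((n, p), dtype=tf.float64))

    @pymanopt.function.TensorFlow(X)
    def cost(X):
        return -tf.tensordot(X, tf.matmul(A, X), axes=2)

    @pymanopt.function.TensorFlow(X)
    def egrad(X):
        return -tf.matmul(A + A.T, X)

    # The explicit egrad overrides autodiff for the gradient; the
    # Hessian-vector product is still derived automatically by the backend.
    problem = Problem(manifold=Grassmann(n, p), cost=cost, egrad=egrad)
    Xopt = TrustRegions().solve(problem)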
26 changes: 11 additions & 15 deletions pymanopt/autodiff/backends/_autograd.py
@@ -10,6 +10,7 @@
 
 from ._backend import Backend
 from .. import make_tracing_backend_decorator
+from ...tools import unpack_singleton_iterable_return_value
 
 
 class _AutogradBackend(Backend):
@@ -29,35 +30,30 @@ def is_compatible(self, objective, argument):
     def compile_function(self, function, arguments):
         return function
 
-    def _unpack_return_value(self, function):
-        @functools.wraps(function)
-        def wrapper(*args, **kwargs):
-            return function(*args, **kwargs)[0]
-        return wrapper
-
     @Backend._assert_backend_available
     def compute_gradient(self, function, arguments):
         num_arguments = len(arguments)
         gradient = autograd.grad(function, argnum=list(range(num_arguments)))
-        if num_arguments > 1:
-            return gradient
-        return self._unpack_return_value(gradient)
+        if num_arguments == 1:
+            return unpack_singleton_iterable_return_value(gradient)
+        return gradient
 
     @Backend._assert_backend_available
     def compute_hessian_vector_product(self, function, arguments):
         num_arguments = len(arguments)
         hessian_vector_product = autograd.hessian_vector_product(
             function, argnum=tuple(range(num_arguments)))
         if num_arguments == 1:
-            return self._unpack_return_value(hessian_vector_product)
+            return unpack_singleton_iterable_return_value(
+                hessian_vector_product)
 
         @functools.wraps(hessian_vector_product)
-        def wrapper(*arguments):
-            num_arguments = len(arguments)
+        def wrapper(*args):
+            num_arguments = len(args)
             assert num_arguments % 2 == 0
-            point = arguments[:num_arguments // 2]
-            vector = arguments[num_arguments // 2:]
-            return hessian_vector_product(*point, vector)
+            arguments = args[:num_arguments // 2]
+            vectors = args[num_arguments // 2:]
+            return hessian_vector_product(*arguments, vectors)
         return wrapper
 
 
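The Autograd change relies on the fact that, when argnum is given as a list or tuple of argument indices, autograd returns one gradient per index, i.e. a tuple, even for a single index. A tiny illustration of that behaviour (a sketch, assuming only this argnum handling):

    import autograd
    import autograd.numpy as np

    def cost(x):
        return np.sum(x ** 2)

    # With a list of argument indices, autograd returns one gradient per
    # index, i.e. a tuple -- even when the list has a single entry.
    gradient = autograd.grad(cost, argnum=[0])
    print(gradient(np.array([1.0, 2.0, 3.0])))  # (array([2., 4., 6.]),)

    # unpack_singleton_iterable_return_value strips that outer tuple, so
    # callers of the backend get the bare gradient for single-argument costs.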
92 changes: 33 additions & 59 deletions pymanopt/autodiff/backends/_tensorflow.py
@@ -10,7 +10,7 @@
 
 from ._backend import Backend
 from .. import make_graph_backend_decorator
-from ...tools import flatten_arguments, group_return_values
+from ...tools import unpack_singleton_iterable_return_value
 
 
 class _TensorFlowBackend(Backend):
@@ -36,56 +36,37 @@ def is_available():
     def is_compatible(self, function, arguments):
         if not isinstance(function, tf.Tensor):
             return False
-        flattened_arguments = flatten_arguments(arguments)
         return all([isinstance(argument, tf.Variable)
-                    for argument in flattened_arguments])
+                    for argument in arguments])
 
     @Backend._assert_backend_available
-    def compile_function(self, function, arguments):
-        flattened_arguments = flatten_arguments(arguments)
-        if len(flattened_arguments) == 1:
-            def unary_function(point):
-                (argument,) = flattened_arguments
-                feed_dict = {argument: point}
-                return self._session.run(function, feed_dict)
-            return unary_function
-
-        def nary_function(arguments):
-            flattened_inputs = flatten_arguments(arguments)
+    def compile_function(self, function, variables):
+        def compiled_function(*args):
             feed_dict = {
-                argument: array
-                for argument, array in zip(flattened_arguments,
-                                           flattened_inputs)
+                variable: argument
+                for variable, argument in zip(variables, args)
             }
             return self._session.run(function, feed_dict)
-        return nary_function
+        return compiled_function
 
     @staticmethod
     def _gradients(function, arguments):
         return tf.gradients(function, arguments,
                             unconnected_gradients=tf.UnconnectedGradients.ZERO)
 
     @Backend._assert_backend_available
-    def compute_gradient(self, function, arguments):
-        flattened_arguments = flatten_arguments(arguments)
-        gradient = self._gradients(function, flattened_arguments)
+    def compute_gradient(self, function, variables):
+        gradients = self._gradients(function, variables)
 
-        if len(flattened_arguments) == 1:
-            (argument,) = flattened_arguments
-
-            def unary_gradient(point):
-                feed_dict = {argument: point}
-                return self._session.run(gradient[0], feed_dict)
-            return unary_gradient
-
-        def nary_gradient(points):
+        def gradient(*args):
             feed_dict = {
-                argument: point
-                for argument, point in zip(flattened_arguments,
-                                           flatten_arguments(points))
+                variable: argument
+                for variable, argument in zip(variables, args)
             }
-            return self._session.run(gradient, feed_dict)
-        return group_return_values(nary_gradient, arguments)
+            return self._session.run(gradients, feed_dict)
+        if len(variables) == 1:
+            return unpack_singleton_iterable_return_value(gradient)
+        return gradient
 
     @staticmethod
     def _hessian_vector_product(function, arguments, vectors):
@@ -122,33 +103,26 @@ def _hessian_vector_product(function, arguments, vectors):
         return _TensorFlowBackend._gradients(element_wise_products, arguments)
 
     @Backend._assert_backend_available
-    def compute_hessian(self, function, arguments):
-        flattened_arguments = flatten_arguments(arguments)
-
-        if len(flattened_arguments) == 1:
-            (argument,) = flattened_arguments
-            zeros = tf.zeros_like(argument)
-            hessian = self._hessian_vector_product(
-                function, [argument], [zeros])
-
-            def unary_hessian(point, vector):
-                feed_dict = {argument: point, zeros: vector}
-                return self._session.run(hessian[0], feed_dict)
-            return unary_hessian
-
-        zeros = [tf.zeros_like(argument) for argument in flattened_arguments]
-        hessian = self._hessian_vector_product(
-            function, flattened_arguments, zeros)
-
-        def nary_hessian(points, vectors):
+    def compute_hessian_vector_product(self, function, variables):
+        zeros = [tf.zeros_like(variable) for variable in variables]
+        hessian = self._hessian_vector_product(function, variables, zeros)
+
+        def hessian_vector_product(*args):
+            num_arguments = len(args)
+            assert num_arguments % 2 == 0
+            arguments = args[:num_arguments // 2]
+            vectors = args[num_arguments // 2:]
             feed_dict = {
-                argument: value for argument, value in zip(
-                    itertools.chain(flattened_arguments, zeros),
-                    itertools.chain(flatten_arguments(points),
-                                    flatten_arguments(vectors)))
+                variable: argument
+                for variable, argument in zip(
+                    itertools.chain(variables, zeros),
+                    itertools.chain(arguments, vectors))
             }
             return self._session.run(hessian, feed_dict)
-        return group_return_values(nary_hessian, arguments)
+        if len(variables) == 1:
+            return unpack_singleton_iterable_return_value(
+                hessian_vector_product)
+        return hessian_vector_product
 
 
 TensorFlow = make_graph_backend_decorator(_TensorFlowBackend)
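All three callables returned by the rewritten backend follow the same pattern: build the graph once, then evaluate it by feeding numerical points for the declared tf.Variable arguments. Below is a stripped-down sketch of that evaluation pattern, assuming TensorFlow 1.x graph mode and that variables can be fed through feed_dict, as the backend itself does:

    import numpy as np
    import tensorflow as tf

    # The declared argument is a tf.Variable, as required by is_compatible.
    x = tf.Variable(np.zeros(3))
    cost = tf.reduce_sum(x ** 2)
    gradient = tf.gradients(
        cost, [x], unconnected_gradients=tf.UnconnectedGradients.ZERO)[0]

    with tf.Session() as session:
        point = np.array([1.0, 2.0, 3.0])
        # Evaluate the graph at a concrete point by feeding the variable,
        # mirroring compiled_function and gradient above.
        print(session.run(cost, feed_dict={x: point}))      # 14.0
        print(session.run(gradient, feed_dict={x: point}))  # [2. 4. 6.]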
11 changes: 11 additions & 0 deletions pymanopt/tools/__init__.py
@@ -93,3 +93,14 @@ def inner(*args):
             i += n
         return groups
     return inner
+
+
+def unpack_singleton_iterable_return_value(function):
+    """Function decorator which unwraps a singleton list or tuple return value.
+    """
+    @functools.wraps(function)
+    def wrapper(*args):
+        result = function(*args)
+        assert isinstance(result, (list, tuple))
+        return result[0]
+    return wrapper
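A toy demonstration of the new helper; the wrapped gradient function here is a hypothetical stand-in that returns a one-element tuple the way the backends' autodiff calls do:

    import functools

    def unpack_singleton_iterable_return_value(function):
        """Function decorator which unwraps a singleton list or tuple return value."""
        @functools.wraps(function)
        def wrapper(*args):
            result = function(*args)
            assert isinstance(result, (list, tuple))
            return result[0]
        return wrapper

    @unpack_singleton_iterable_return_value
    def gradient(x):
        # Stand-in for the (gradient,) tuples produced by the backends.
        return (2 * x,)

    print(gradient(3.0))  # 6.0 rather than (6.0,)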
