In [1]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.python.framework import graph_util
import sopt.benchmarks.ops.tensorflow.flops_registry_custom
from tensorflow.python.ops.gradients_impl import _hessian_vector_product
from sopt.optimizers.tensorflow.curveball import Curveball
from sopt.optimizers.tensorflow.lma import LMA

In [2]:
# Reload edited modules (e.g. the sopt optimizers imported above) automatically
# before each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [3]:
def load_pb(pb):
    """Read a serialized GraphDef file and import it into a fresh tf.Graph.

    Args:
        pb: path to the binary GraphDef protobuf file.

    Returns:
        A new ``tf.Graph`` holding the imported nodes, with no name prefix.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names instead of prefixing "import/".
        tf.import_graph_def(graph_def, name='')
    return graph

In [4]:
def tf_y_pred(z, tf_affine_transform):
    """Left-multiply a flat tensor of stacked columns by a matrix and flatten.

    Args:
        z: tensor whose total size is a multiple of the number of columns of
            ``tf_affine_transform``; it is reshaped (row-major) to
            ``(n_cols, -1)``.
        tf_affine_transform: 2-D tensor of shape ``(n_rows, n_cols)``.

    Returns:
        1-D tensor equal to ``tf_affine_transform @ reshape(z, [n_cols, -1])``
        flattened.
    """
    # Take the inner dimension from the transform rather than hard-coding 3.
    # With a statically shaped transform (as in this notebook) this is a plain
    # int, so the graph, and hence the FLOP count, is unchanged. Otherwise fall
    # back to the dynamic shape.
    n_cols = tf.compat.dimension_value(tf_affine_transform.shape[-1])
    if n_cols is None:
        n_cols = tf.shape(tf_affine_transform)[-1]
    return tf.reshape(tf_affine_transform @ tf.reshape(z, [n_cols, -1]), [-1])
def tf_loss(y_pred, tf_y_true):
    """Squared-error loss: 0.5 * sum((tf_y_true - y_pred) ** 2) as a scalar tensor."""
    residual = tf_y_true - y_pred
    return 0.5 * tf.reduce_sum(residual**2)

In [5]:
# Synthetic problem: y_true = A @ z_true, with A a 3 x 3 orthonormal matrix built
# from the eigenvectors of a random symmetric matrix. Seed the generator so the
# data is reproducible under Restart & Run All.
SEED = 0
rng = np.random.RandomState(SEED)

z_true = rng.randn(3, 100).astype('float32')

random_mat = rng.randn(3, 3)
random_symmetric_mat = random_mat + random_mat.T
# eigh exploits symmetry: it returns real eigenvalues and orthonormal
# eigenvectors. np.linalg.eig guarantees neither for general input.
evals, evecs = np.linalg.eigh(random_symmetric_mat)
affine_transform = evecs

y_true = affine_transform @ z_true
y_true_flat = y_true.flatten()

z_guess = rng.randn(300).astype('float32')

In [6]:
# Counting FLOPs with the TensorFlow profiler.
# Reference:
# https://stackoverflow.com/questions/45085938/tensorflow-is-there-a-way-to-measure-flops-for-a-model
#
# Further documentation:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/g3doc/profile_model_architecture.md



In [7]:
# Forward model only: prediction + squared-error loss.
tf.reset_default_graph()
var = tf.get_variable(name='var', initializer=z_guess, dtype=tf.float32)

tf_y_true = tf.convert_to_tensor(y_true_flat, name='y_true', dtype=tf.float32)
tf_affine_transform = tf.convert_to_tensor(
    affine_transform, name='affine_transform', dtype=tf.float32)

preds = tf_y_pred(var, tf_affine_transform)
loss_tensor = tf_loss(preds, tf_y_true)

session = tf.Session()
tf.global_variables_initializer().run(session=session)

# FLOP count over the default graph. run_meta is never populated by a traced
# session.run, so this is a static count of the graph built above.
opts = tf.profiler.ProfileOptionBuilder.float_operation()
run_meta = tf.RunMetadata()
flops_fwd = tf.profiler.profile(
    run_meta=run_meta,
    cmd='graph',
    options=opts,
)
print(flops_fwd.total_float_ops)

W0926 14:35:56.681059 140149697488704 deprecation.py:323] From /raid/home/skandel/miniconda3/envs/ad/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:2759: tensor_shape_from_node_def_name (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`


2700


In [8]:
# Forward model + gradient of the loss with respect to the variable.
tf.reset_default_graph()
var = tf.get_variable(name='var', initializer=z_guess, dtype=tf.float32)

tf_y_true = tf.convert_to_tensor(y_true_flat, name='y_true', dtype=tf.float32)
tf_affine_transform = tf.convert_to_tensor(
    affine_transform, name='affine_transform', dtype=tf.float32)

preds = tf_y_pred(var, tf_affine_transform)
loss_tensor = tf_loss(preds, tf_y_true)

# Reverse-mode d(loss)/d(var); these ops are part of the profiled graph.
gradients = tf.gradients(ys=[loss_tensor], xs=[var])

session = tf.Session()
tf.global_variables_initializer().run(session=session)

# Static FLOP count over the whole default graph (forward + backward ops).
opts = tf.profiler.ProfileOptionBuilder.float_operation()
run_meta = tf.RunMetadata()
flops_fwd_grad = tf.profiler.profile(
    run_meta=run_meta,
    cmd='graph',
    options=opts,
)
print(flops_fwd_grad.total_float_ops)

W0926 14:35:56.758760 140149697488704 deprecation.py:323] From /raid/home/skandel/miniconda3/envs/ad/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


9002


In [9]:
# Gauss-Newton vector product G z = J^T (J z), with J = d(preds)/d(var).
# For this loss, 0.5 * sum((y - preds)**2), the Hessian w.r.t. preds is the
# identity, so nothing is inserted between the two Jacobian products.
tf.reset_default_graph()
var = tf.get_variable(name='var', initializer=z_guess, dtype=tf.float32)
z = tf.get_variable(name='z', initializer=tf.zeros_like(z_guess, dtype='float32'), dtype=tf.float32)
dummy_var = tf.get_variable(name='dummy', initializer=tf.zeros_like(y_true_flat, dtype='float32'), dtype=tf.float32)

tf_y_true = tf.convert_to_tensor(y_true_flat, name='y_true', dtype=tf.float32)
tf_affine_transform = tf.convert_to_tensor(
    affine_transform, name='affine_transform', dtype=tf.float32)

preds = tf_y_pred(var, tf_affine_transform)
loss_tensor = tf_loss(preds, tf_y_true)

# J z via the double-backward trick: build u -> J^T u with a dummy cotangent u,
# then differentiate that (linear in u) w.r.t. u along z.
vjp_dummy = tf.gradients(ys=preds, xs=var, grad_ys=dummy_var)[0]
jvpz = tf.gradients(ys=vjp_dummy, xs=dummy_var, grad_ys=z)[0]

# Pull J z back through the model: J^T (J z).
gvpz = tf.gradients(ys=preds, xs=var, grad_ys=jvpz)[0]

session = tf.Session()
tf.global_variables_initializer().run(session=session)

opts = tf.profiler.ProfileOptionBuilder.float_operation()
run_meta = tf.RunMetadata()
flops_gvp = tf.profiler.profile(
    run_meta=run_meta,
    cmd='graph',
    options=opts,
)
print(flops_gvp.total_float_ops)

13500


In [10]:
# Generalized Gauss-Newton vector product G z = J^T H_L (J z), where H_L is the
# Hessian of the loss w.r.t. preds, applied here through an explicit
# Hessian-vector product instead of being assumed to be the identity.
tf.reset_default_graph()
var = tf.get_variable('var', dtype=tf.float32, initializer=z_guess)
z = tf.get_variable('z', dtype=tf.float32, initializer=tf.zeros_like(z_guess, dtype='float32'))
dummy_var = tf.get_variable('dummy', dtype=tf.float32, initializer=tf.zeros_like(y_true_flat, dtype='float32'))

tf_y_true = tf.convert_to_tensor(y_true_flat, dtype='float32', name='y_true')
tf_affine_transform = tf.convert_to_tensor(affine_transform, dtype='float32', name='affine_transform')

preds = tf_y_pred(var, tf_affine_transform)
loss_tensor = tf_loss(preds, tf_y_true)

# J z via the double-backward trick (dummy cotangent, then differentiate along z).
vjp_dummy = tf.gradients(preds, var, dummy_var)[0]
jvpz = tf.gradients(vjp_dummy, dummy_var, z)[0]

# H_L (J z): Hessian-vector product of the loss w.r.t. preds.
hjvpz = _hessian_vector_product([loss_tensor], [preds], [jvpz])

# Pull back through the model: J^T H_L J z.
gvpz = tf.gradients(preds, var, hjvpz)[0]

session = tf.Session()
session.run(tf.global_variables_initializer())

run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# Stored under its own name so the plain Gauss-Newton count (`flops_gvp`)
# from the earlier cell is not overwritten.
flops_ggnvp = tf.profiler.profile(run_meta=run_meta, cmd='graph', options=opts)
print(flops_ggnvp.total_float_ops)

20703


In [11]:
# Curveball optimizer step (without hessian-vector-product; squared_loss=True).
# The count covers the whole default graph built below, including every op
# that `optimizer.minimize()` adds.
tf.reset_default_graph()

var = tf.get_variable('var', dtype=tf.float32, initializer=z_guess)

tf_y_true = tf.convert_to_tensor(y_true_flat, dtype='float32', name='y_true')
tf_affine_transform = tf.convert_to_tensor(affine_transform, dtype='float32', name='affine_transform')

preds_fn = lambda x: tf_y_pred(x, tf_affine_transform)
loss_fn = lambda x: tf_loss(x, tf_y_true)

optimizer = Curveball(var, predictions_fn=preds_fn, loss_fn=loss_fn, squared_loss=True, name='opt')
minimize_op = optimizer.minimize()

session = tf.Session()
session.run(tf.global_variables_initializer())

run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# Stored under its own name so the Gauss-Newton count (`flops_gvp`) from the
# earlier cell is not overwritten by the Curveball count.
flops_curveball = tf.profiler.profile(run_meta=run_meta, cmd='graph', options=opts)
print(flops_curveball.total_float_ops)

W0926 14:35:57.092910 140149697488704 deprecation_wrapper.py:119] From /raid/home/skandel/code/sopt/sopt/optimizers/tensorflow/curveball.py:50: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W0926 14:35:57.101360 140149697488704 deprecation_wrapper.py:119] From /raid/home/skandel/code/sopt/sopt/optimizers/tensorflow/curveball.py:54: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0926 14:35:57.210848 140149697488704 deprecation_wrapper.py:119] From /raid/home/skandel/code/sopt/sopt/optimizers/tensorflow/curveball.py:225: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.



29728


In [12]:
# LMA optimizer step (without hessian-vector-product; squared_loss=True,
# max_cg_iter=10). The count covers the whole default graph built below,
# including every op that `optimizer.minimize()` adds.
tf.reset_default_graph()
var = tf.get_variable('var', dtype=tf.float32, initializer=z_guess)

tf_y_true = tf.convert_to_tensor(y_true_flat, dtype='float32', name='y_true')
tf_affine_transform = tf.convert_to_tensor(affine_transform, dtype='float32', name='affine_transform')

preds_fn = lambda x: tf_y_pred(x, tf_affine_transform)
loss_fn = lambda x: tf_loss(x, tf_y_true)

optimizer = LMA(var, predictions_fn=preds_fn, loss_fn=loss_fn, squared_loss=True,
                name='opt', max_cg_iter=10)
minimize_op = optimizer.minimize()

session = tf.Session()
session.run(tf.global_variables_initializer())

run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# Stored under its own name so the Gauss-Newton count (`flops_gvp`) from the
# earlier cell is not overwritten by the LMA count.
flops_lma = tf.profiler.profile(run_meta=run_meta, cmd='graph', options=opts)
print(flops_lma.total_float_ops)


W0926 14:35:57.441787 140149697488704 lma.py:123] The ftol, gtol, and xtol conditions are adapted from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.least_squares.html.This is a test version, and there is no guarantee that these work as intended.
W0926 14:35:57.453846 140149697488704 deprecation_wrapper.py:119] From /raid/home/skandel/code/sopt/sopt/optimizers/tensorflow/lma.py:130: The name tf.assert_greater is deprecated. Please use tf.compat.v1.assert_greater instead.



39931
