In [None]:
import functools
import tensorflow as tf
from core.model import ModelMNIST10x10_base
from core.hessian import hessians_highrank

class Implement(ModelMNIST10x10_base):
    """MNIST 10x10 model whose update direction is a damped Newton step.

    For every tensor in ``var_refs``, ``compute_gradient`` returns
    ``(H + HESSIAN_DAMPING * I)^-1 g``. Here ``g`` is that tensor's gradient
    and ``H`` is its own Hessian block from ``hessians_highrank``, flattened
    to ``d x d``. Cross-variable Hessian terms are not used, so this is a
    block-diagonal Newton approximation.
    """

    # Diagonal damping added to every Hessian block before solving. It
    # improves conditioning but does NOT guarantee positive definiteness.
    HESSIAN_DAMPING = 1e-1

    def __init__(self, batch_size=128, opt=None):
        """Build the model.

        Args:
            batch_size: batch size forwarded to the base model.
            opt: optimizer forwarded to the base model (presumably used to
                apply the directions returned by ``compute_gradient`` --
                confirm in ``ModelMNIST10x10_base``). ``None`` creates a fresh
                ``MomentumOptimizer(0.1, 0.0)`` for this instance. A fresh
                one is built per instance because a default evaluated at
                class-definition time would be shared by every model,
                including models built in different graphs.
        """
        if opt is None:
            opt = tf.train.MomentumOptimizer(0.1, 0.0, use_nesterov=False)
        super(Implement, self).__init__(batch_size, opt)

    def compute_gradient(self, cost, var_refs):
        """Return damped-Newton update directions for ``var_refs``.

        Args:
            cost: loss tensor to differentiate (assumed scalar -- TODO confirm).
            var_refs: list of variables to differentiate with respect to.

        Returns:
            A list parallel to ``var_refs``. Each entry has the gradient's
            shape and equals ``(H + HESSIAN_DAMPING * I)^-1 g``. An entry is
            ``None`` where ``tf.gradients`` returned ``None`` (variable not
            connected to ``cost``), mirroring the ``tf.gradients`` convention.

        Raises:
            ValueError: if a gradient's static shape is not fully defined,
                since the flattened size ``d`` must be known to build the
                ``d x d`` system.
        """
        grads = tf.gradients(
            cost, var_refs,
            grad_ys=None, aggregation_method=None, colocate_gradients_with_ops=True)
        hessis = hessians_highrank(
            cost, var_refs, gradients=grads,
            aggregation_method=None, colocate_gradients_with_ops=True)

        second_order_grads = []
        for idx, (grad, hess) in enumerate(zip(grads, hessis)):
            if grad is None:
                # Unconnected variable: no gradient, so no Newton step.
                second_order_grads.append(None)
                continue

            shape = grad.shape
            # num_elements() is 1 for scalars and None for unknown shapes.
            d = shape.num_elements()
            if d is None:
                raise ValueError(
                    'compute_gradient needs a fully-defined gradient shape for '
                    'variable {}; got {}'.format(idx, shape))

            grad_col = tf.reshape(grad, [d, 1])
            # Flatten this variable's Hessian block to d x d (assumes
            # hessians_highrank returns it with shape `shape + shape` -- TODO
            # confirm), then add damping on the diagonal in the same dtype.
            hess_damped = (tf.reshape(hess, [d, d])
                           + tf.eye(d, dtype=hess.dtype) * self.HESSIAN_DAMPING)
            # Solve (H + damping*I) x = g directly instead of forming the
            # inverse: cheaper and numerically more stable.
            delta = tf.reshape(tf.matrix_solve(hess_damped, grad_col), shape)
            # If steps explode, consider tf.clip_by_value / tf.clip_by_norm here.
            second_order_grads.append(delta)

            tf.summary.histogram('update/gradient/{}'.format(idx), grad_col)
            tf.summary.histogram('update/hessian/{}'.format(idx), hess_damped)
            tf.summary.histogram('update/delta/{}'.format(idx), delta)
        return second_order_grads

In [None]:
import logging
import sys

# Send DEBUG-and-above records from the root logger to the notebook's stdout.
# This must run before importing core.utils: basicConfig does nothing once
# the root logger already has handlers.
logging.basicConfig(stream=sys.stdout,
                    level=logging.DEBUG,
                    format='[%(levelname)s] %(message)s')

from core.utils import loop

BATCH_SIZE = 128 * 100
N_ROUNDS = 30  # second argument to core.utils.loop (epochs? -- TODO confirm)
# NOTE(review): '*' in this directory name is a shell glob character; quote
# the path when passing it to shell tools such as `tensorboard --logdir`.
SUMMARY_DIR = './summary/6_2_second_order_optimization_batch_128*100'

model = Implement(batch_size=BATCH_SIZE)
history = loop(model, N_ROUNDS, SUMMARY_DIR)

[1124 15:35:30 @fs.py:89] WRN Env var $TENSORPACK_DATASET not set, using /root/tensorpack_data for datasets.
[1124 15:35:31 @prefetch.py:169] [PrefetchData] Will fork a dataflow more than one times. This assumes the datapoints are i.i.d.
[INFO] session initialized


In [None]:
import matplotlib
%matplotlib inline

from core.utils import plot_jupyter
plot_jupyter(history)