In [1]:
import functools
import tensorflow as tf
from core.model import ModelMNIST10x10_base
from core.hessian import diagonal_inverse_hessians_highrank

class Implement(ModelMNIST10x10_base):
    """MNIST model whose update is each gradient rescaled element-wise.

    ``compute_gradient`` multiplies every gradient by the matching output of
    ``diagonal_inverse_hessians_highrank`` (presumably a diagonal
    inverse-Hessian estimate, judging by the helper's name -- confirm against
    ``core.hessian``) and returns those products in place of the raw gradients.
    """

    def __init__(self, batch_size=128, opt=None):
        """Create the model.

        Args:
            batch_size: forwarded unchanged to the base-class constructor.
            opt: optimizer forwarded to the base-class constructor. When
                ``None`` a Nesterov ``MomentumOptimizer(0.1, 0.9)`` is built.
        """
        if opt is None:
            # Build the default here rather than in the signature: a default
            # argument is evaluated once, when the ``def`` runs, so every
            # instance would otherwise share one optimizer object.
            opt = tf.train.MomentumOptimizer(0.1, 0.9, use_nesterov=True)
        super(Implement, self).__init__(batch_size, opt)

    def compute_gradient(self, cost, var_refs):
        """Return the rescaled gradient of ``cost`` for every entry of ``var_refs``.

        Args:
            cost: tensor to differentiate.
            var_refs: sequence of variables to differentiate with respect to.

        Returns:
            A list with one tensor per entry of ``var_refs``: the gradient
            multiplied element-wise by the matching helper output reshaped to
            the gradient's shape.

        Side effects:
            Registers three histogram summaries per variable (gradient, helper
            output and resulting update), tagged with the variable's index.

        NOTE(review): ``tf.gradients`` yields ``None`` for a variable that
        ``cost`` does not depend on; such entries are not handled here.
        """
        grads = tf.gradients(
            cost, var_refs,
            grad_ys=None, aggregation_method=None, colocate_gradients_with_ops=True)
        # The already-built gradients are passed in via ``gradients=``.
        hessis = diagonal_inverse_hessians_highrank(
            cost, var_refs, gradients=grads,
            aggregation_method=None, colocate_gradients_with_ops=True)

        second_order_grads = []
        for l, (g, h) in enumerate(zip(grads, hessis)):
            # Give the helper output the gradient's shape so the product
            # below is strictly element-wise.
            h_inv = tf.reshape(h, g.shape)
            delta = tf.multiply(h_inv, g)
            second_order_grads.append(delta)

            tf.summary.histogram('update/gradient/{}'.format(l), g)
            tf.summary.histogram('update/hessian/{}'.format(l), h)
            tf.summary.histogram('update/delta/{}'.format(l), delta)
        return second_order_grads

In [2]:
import sys
import logging
# Send INFO-and-above log records to stdout so they appear in the cell output.
# This stays ahead of the core.utils import, as in the original ordering
# (presumably so that module's log output picks up this format -- confirm).
logging.basicConfig(
    stream=sys.stdout,
    format='[%(levelname)s] %(message)s',
    level=logging.INFO,
)

from core.utils import loop

# Batch of 128 * 100 samples, matching the "batch_128x100" summary directory.
model = Implement(batch_size=128 * 100)
# Second argument is the iteration count passed to loop (presumably epochs);
# the returned history is consumed by the plotting cell.
history = loop(
    model,
    30,
    './summary/8_2_incomplete_second_order_optimization_batch_128x100',
)

[32m[1124 14:23:46 @fs.py:89][0m [5m[31mWRN[0m Env var $TENSORPACK_DATASET not set, using /root/tensorpack_data for datasets.
[32m[1124 14:23:46 @prefetch.py:169][0m [PrefetchData] Will fork a dataflow more than one times. This assumes the datapoints are i.i.d.
[INFO] session initialized
[INFO] [0000] [train] cost:0.575 accuracy:0.087 elapsed:1.617sec [valid] cost:0.385 accuracy:0.089 elapsed:0.040sec
[INFO] [0001] [train] cost:0.344 accuracy:0.091 elapsed:0.193sec [valid] cost:0.327 accuracy:0.096 elapsed:0.006sec
[INFO] [0002] [train] cost:0.326 accuracy:0.090 elapsed:0.193sec [valid] cost:0.325 accuracy:0.114 elapsed:0.006sec
[INFO] [0003] [train] cost:0.325 accuracy:0.114 elapsed:0.190sec [valid] cost:0.325 accuracy:0.114 elapsed:0.007sec
[INFO] [0004] [train] cost:0.325 accuracy:0.122 elapsed:0.189sec [valid] cost:0.324 accuracy:0.119 elapsed:0.006sec
[INFO] [0005] [train] cost:0.324 accuracy:0.132 elapsed:0.188sec [valid] cost:0.324 accuracy:0.174 elapsed:0.007sec
[INFO] [

In [None]:
# Render matplotlib figures inline in the notebook output.
import matplotlib
%matplotlib inline

from core.utils import plot_jupyter
# Plot the value returned by loop(...) in the training cell; `history` must
# already exist in the kernel, so that cell has to be run first.
plot_jupyter(history)