
Copyright (c) 2024 Alessandro Temperoni

# Approximated Hessian Chain Rule

Backpropagating the Hessian is more complicated than backpropagating the gradient. For a function $L(z)$
and a parametrization $z=z(w)$ we have

$$
D^2_{w} L = \underbrace{D^2_z L \bullet  D_w z \bullet D_{w} z}_{\text{linearization effect}} + \underbrace{D_z L \bullet D^2_w z}_{\text{curvature effect}}
$$

where $\bullet$ denotes a tensor product contracted over the appropriate axes.
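
Written out in coordinates, with $i, j$ indexing the components of $w$ and $k, l$ those of $z$, the chain rule above reads:

$$
\frac{\partial^2 L}{\partial w_i \partial w_j} = \underbrace{\sum_{k,l} \frac{\partial^2 L}{\partial z_k \partial z_l}\,\frac{\partial z_k}{\partial w_i}\,\frac{\partial z_l}{\partial w_j}}_{\text{linearization effect}} + \underbrace{\sum_k \frac{\partial L}{\partial z_k}\,\frac{\partial^2 z_k}{\partial w_i \partial w_j}}_{\text{curvature effect}}
$$

In matrix form, with $J = D_w z$, the first term is $J^\top (D^2_z L)\, J$, the (generalized) Gauss-Newton matrix, which is positive semidefinite whenever $D^2_z L$ is (for example, for cross-entropy on logits).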

We validate empirically (for a theoretical argument, see the paper) that for neural networks, under certain conditions, the first term dominates. This makes it possible to approximate Hessian computations very efficiently, e.g. at initialization.
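
As a minimal sanity check of the decomposition (a toy illustration, not the paper's experiment), the sketch below uses a hypothetical one-layer model $z = \tanh(Aw)$ with cross-entropy loss, where both terms have closed forms. It checks that their sum matches a finite-difference Hessian and prints the relative size of the curvature term. The names `A`, `w`, `y`, `hessian_terms`, `grad_w` and `fd_hessian` are our own assumptions.

```python
import numpy as np

# Hypothetical toy model, for illustration only:
#   z(w) = tanh(A @ w),  L(z) = logsumexp(z) - z[y]  (cross-entropy with label y)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def grad_w(A, w, y):
    """Gradient D_w L = J^T D_z L, with J = D_w z."""
    z = np.tanh(A @ w)
    g = softmax(z)
    g[y] -= 1.0                            # D_z L = softmax(z) - onehot(y)
    J = (1.0 - z**2)[:, None] * A          # D_w z, shape (m, n)
    return J.T @ g

def hessian_terms(A, w, y):
    """Split D^2_w L into the linearization and curvature terms."""
    z = np.tanh(A @ w)
    p = softmax(z)
    g = p.copy()
    g[y] -= 1.0                            # D_z L
    H_z = np.diag(p) - np.outer(p, p)      # D^2_z L
    J = (1.0 - z**2)[:, None] * A          # D_w z
    linearization = J.T @ H_z @ J
    # D^2_w z_k = tanh''(u_k) a_k a_k^T, with tanh''(u) = -2 tanh(u) (1 - tanh(u)^2)
    tanh_dd = -2.0 * z * (1.0 - z**2)
    curvature = (A.T * (g * tanh_dd)) @ A  # sum_k g_k tanh''(u_k) a_k a_k^T
    return linearization, curvature

def fd_hessian(A, w, y, eps=1e-5):
    """Reference Hessian from central differences of the analytic gradient."""
    n = w.size
    H = np.empty((n, n))
    for i in range(n):
        e = np.zeros(n)
        e[i] = eps
        H[:, i] = (grad_w(A, w + e, y) - grad_w(A, w - e, y)) / (2 * eps)
    return H

rng = np.random.default_rng(0)
A = rng.normal(scale=0.1, size=(10, 6))   # 10 outputs (classes), 6 parameters
w = rng.normal(size=6)
lin, curv = hessian_terms(A, w, y=3)
H = fd_hessian(A, w, y=3)
print("max |lin + curv - H|:", np.abs(lin + curv - H).max())
print("||curv|| / ||lin||:", np.linalg.norm(curv) / np.linalg.norm(lin))
```

In this toy case the printed ratio depends on the scale of `A` and `w`: for small pre-activations $\tanh''(u) \approx -2u$, so the curvature term shrinks relative to the linearization term.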

# Empirical Evaluation

This empirical test confirms the theory and shows the importance of a correct weight initialization. Moreover, computing the Hessian allows us to guide the optimization (step size) and to control the gain from the gradients during backpropagation.
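
One standard way to connect curvature to the step size (a textbook fact for a quadratic model, not necessarily how the paper sets it): gradient descent on $f(w) = \tfrac12 w^\top H w$ converges only if the learning rate is below $2/\lambda_{\max}$, where $\lambda_{\max}$ is the largest eigenvalue of $H$. The sketch estimates $\lambda_{\max}$ by power iteration; the toy matrix `H` and the helpers `top_eigenvalue` and `run_gd` are hypothetical.

```python
import numpy as np

def top_eigenvalue(H, n_iter=200, seed=0):
    """Power iteration: largest-magnitude eigenvalue of a symmetric matrix H
    (the largest eigenvalue when H is positive semidefinite)."""
    v = np.random.default_rng(seed).normal(size=H.shape[0])
    for _ in range(n_iter):
        v = H @ v
        v /= np.linalg.norm(v)
    return v @ H @ v                 # Rayleigh quotient, v has unit norm

def run_gd(H, lr, steps=100):
    """Plain gradient descent on f(w) = 0.5 * w^T H w from w = (1, ..., 1);
    returns the final value of f."""
    w = np.ones(H.shape[0])
    for _ in range(steps):
        w = w - lr * (H @ w)         # the gradient of f is H @ w
    return 0.5 * w @ H @ w

H = np.diag([10.0, 1.0, 0.1])        # toy Hessian with very different curvatures
lam = top_eigenvalue(H)
lr_max = 2.0 / lam                    # stability threshold for the step size
print(lam, lr_max)
print(run_gd(H, 0.9 * lr_max))       # below the threshold: f decreases
print(run_gd(H, 1.1 * lr_max))       # above the threshold: f blows up along the top eigenvector
```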

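## Different Datasets

Each of the following cells loads one dataset into `train_inputs` and `train_labels`; run only the one you want, since each cell overwrites the previous one. Note that `build_network` below expects 28x28 grayscale inputs (MNIST); CIFAR-10, CIFAR-100 and SVHN images are 32x32x3, so the `Input` shape must be changed accordingly.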

In [None]:
mnist = tf.keras.datasets.mnist
(train_inputs, train_labels), _ = mnist.load_data()
train_inputs = train_inputs / 255.0

In [None]:
cifar100 = tf.keras.datasets.cifar100
(train_inputs, train_labels), _ = cifar100.load_data()
train_inputs = train_inputs / 255.0
train_labels = train_labels.flatten()  # CIFAR labels are loaded with shape (N, 1)

In [None]:
cifar10 = tf.keras.datasets.cifar10
(train_inputs, train_labels), _ = cifar10.load_data()
train_inputs = train_inputs / 255.0
train_labels = train_labels.flatten()  # CIFAR labels are loaded with shape (N, 1)

In [None]:
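import tensorflow_datasets as tfds
# SVHN is not part of tf.keras.datasets; load the whole train split from TensorFlow Datasets as NumPy arrays
train_inputs, train_labels = tfds.as_numpy(
    tfds.load('svhn_cropped', split='train', batch_size=-1, as_supervised=True))
train_inputs = train_inputs / 255.0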

## Build Model

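The model is built with the Keras API and runs in Google Colab. It has two inputs:
 - the integer labels
 - the input images, 28x28

A Flatten() layer first turns each 28x28 image into a 784-dimensional vector. The flattened vector then goes through
 - dense1: 30 units, with the kernel initialized from a normal distribution with standard deviation `sd`
 - dense2: 30 units, with the default Keras initializer
 - dense3: 10 units with linear activation, producing the logits

Finally, we use sparse categorical cross-entropy (cross-entropy with integer labels) on the logits as the loss function.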
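
The custom `SparseCategoricalCrossentropy` defined below computes, for each example, the log-sum-exp of the logits minus the logit of the true class, which equals $-\log \mathrm{softmax}(z)_y$. A small NumPy sketch with toy logits checks this identity; the helper `sparse_ce_from_logits` is hypothetical.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Per-example loss: logsumexp(logits) - logits[true class]."""
    m = logits.max(axis=-1, keepdims=True)
    Z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=-1))   # numerically stable logsumexp
    true_logits = logits[np.arange(len(labels)), labels]    # same lookup as tf.gather_nd
    return Z - true_logits

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
loss = sparse_ce_from_logits(labels, logits)

# Reference: negative log of the softmax probability of the true class
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
reference = -np.log(probs[np.arange(len(labels)), labels])
print(loss, reference)
```

For the second example all logits are equal, so the loss is $\log 3$ regardless of the label.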

In [None]:
%tensorflow_version 1.x

import tensorflow as tf
from tensorflow.keras import backend as K

def SparseCategoricalCrossentropy(labels,logits):
  # Per-example cross-entropy from logits: logsumexp(logits) - logits[label]
  Z = tf.reduce_logsumexp(logits,axis=-1)
  # Pair each row index with its label to pick out the logit of the true class
  lookup_labels = tf.stack([tf.range(tf.shape(labels)[0]),tf.cast(labels,tf.int32)],1)
  true_logits = tf.gather_nd(logits,lookup_labels,batch_dims=0)
  return -true_logits + Z

def build_network(activation='linear',sd=0.01):

  # No fixed batch_size: the same graph must accept training batches (32)
  # and the larger sample (1024) used to estimate the Hessian
  inputs = tf.keras.layers.Input(shape=[28,28],dtype=tf.float32,name='inputs')
  labels = tf.keras.layers.Input(shape=[],dtype=tf.int32,name='labels')

  # Flatten 28x28 images into 784-dimensional vectors
  layer1 = tf.keras.layers.Flatten()
  out1 = layer1(inputs)

  # dense1: the layer whose initialization (stddev=sd) is under study
  layer2 = tf.keras.layers.Dense(30,activation=activation,kernel_initializer=tf.keras.initializers.RandomNormal(stddev=sd),name='dense1')
  out2 = layer2(out1)

  layer3 = tf.keras.layers.Dense(30,activation=activation,name='dense2')
  out3 = layer3(out2)

  # dense3 outputs the 10 logits
  layer4 = tf.keras.layers.Dense(10,activation='linear',name='dense3')
  out = layer4(out3)
  model = tf.keras.Model(inputs=[inputs,labels], outputs=out)

  # The loss depends on the labels input, so it is attached with add_loss
  loss = SparseCategoricalCrossentropy(model.input[1],model.output)
  model.add_loss(loss)
  sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
  model.compile(optimizer=sgd)

  return model


## Run Experiment

In this section we train the model and compute the Hessian of the loss with respect to the `dense1` kernel. Some additional operations are needed to put the Hessian in the right shape: `tf.hessians` returns it as a 4-dimensional tensor (the kernel shape repeated twice), so we reshape it into a square matrix, with one row and one column per kernel entry, before taking its norm. We repeat the experiment for 3 standard deviations of the `dense1` initializer and record, at every training iteration, the mini-batch loss together with the Hessian norm measured at initialization, to see how the two are related.
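
To see what the reshape does, here is a NumPy sketch on a hypothetical quadratic $f(W) = \tfrac12 \lVert XW \rVert_F^2$, whose Hessian with respect to a kernel $W$ of shape $(a, b)$ has the closed form $\partial^2 f / \partial W_{ij} \partial W_{kl} = (X^\top X)_{ik}\,\delta_{jl}$. A row-major reshape (as done by both `np.reshape` and `tf.reshape`) turns the 4-D tensor into an $(ab) \times (ab)$ matrix indexed by the flattened kernel, and leaves the Frobenius norm unchanged. `X`, `a` and `b` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 4, 3                       # toy kernel shape; dense1's kernel is (784, 30)
X = rng.normal(size=(5, a))       # hypothetical batch of 5 inputs
G = X.T @ X

# Hessian of f(W) = 0.5 * ||X @ W||_F^2 with respect to W, as a 4-D tensor:
#   H4[i, j, k, l] = d^2 f / (dW[i, j] dW[k, l]) = G[i, k] * delta(j, l)
H4 = np.einsum('ik,jl->ijkl', G, np.eye(b))
print(H4.shape)                   # kernel shape repeated twice: (4, 3, 4, 3)

# Row-major reshape into a square matrix over the flattened kernel (index i*b + j)
H2 = H4.reshape(a * b, a * b)
print(H2.shape)                   # (12, 12)

# For this f the flattened Hessian equals kron(G, I_b),
# and reshaping does not change the Frobenius norm
print(np.allclose(H2, np.kron(G, np.eye(b))))
print(np.linalg.norm(H4), np.linalg.norm(H2))
```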

In [None]:
import numpy as np

batch_size = 32
sds = [0.00001,0.01,1] # candidate standard deviations for the dense1 initializer
n_iter = 1000
outs = []
for sd in sds:
  ## Build a fresh model
  K.clear_session()
  np.random.seed(1234)      # same sequence of batches for every sd
  tf.set_random_seed(1234)  # reproducible weight initialization
  model = build_network('relu',sd)
  ## Choose the weights to investigate
  g = model.get_layer('dense1').kernel
  # By default Keras aggregates the loss by sum; we take the mean instead, so that if the Hessian explodes, it is not because of the aggregation over the batch
  loss_agg = tf.reduce_mean(model.total_loss)
  # tf.hessians returns a tensor of shape kernel_shape + kernel_shape (4-D);
  # reshape it into a square matrix over the flattened kernel
  H = tf.hessians(loss_agg,g)[0]
  shape = g.get_shape().as_list()
  H = tf.reshape(tf.squeeze(H),[shape[0]*shape[1],shape[0]*shape[1]])
  H_norm = tf.norm(H)  # Frobenius norm
  # Estimate the Hessian norm once, at initialization, on 1024 random samples
  batch_sample = np.random.randint(0,len(train_labels),size=[1024])
  batch_inputs, batch_labels = train_inputs[batch_sample], train_labels[batch_sample]
  feed_dict = {model.inputs[0]:batch_inputs,model.inputs[1]:batch_labels}
  sess = K.get_session()
  h_norm_val = sess.run(H_norm,feed_dict)

  # Train, recording the loss on each mini-batch before the update
  # (no re-seeding inside the loop, otherwise every iteration would draw the same batch)
  for i_iter in range(n_iter):
    batch_sample = np.random.randint(0,len(train_labels),size=[batch_size])
    batch_inputs, batch_labels = train_inputs[batch_sample], train_labels[batch_sample]
    feed_dict = {model.inputs[0]:batch_inputs,model.inputs[1]:batch_labels}
    loss_val = sess.run(loss_agg,feed_dict)
    outs.append((sd,i_iter,loss_val,h_norm_val))
    # One SGD step on the same batch
    model.train_on_batch([batch_inputs,batch_labels])

## Summarize

We collect the results in a pandas DataFrame and plot, for each of the 3 standard deviations used in the experiment, the loss and the Hessian norm against the training iteration. For a large sd, the Hessian explodes and the loss oscillates from very large values (around 100) to very small ones. For a small sd, the Hessian is very small, but so are the gradients, so each update yields only a very small gain.

In [None]:
import pandas as pd
from matplotlib import pyplot as plt

d = pd.DataFrame(outs,columns=['sd','iter','loss','hessian'])
print(d.head())

# One figure per standard deviation: loss on the left axis, Hessian norm on the right
for sd, tmp in d.groupby('sd'):
  fig, ax1 = plt.subplots()
  ax2 = ax1.twinx()

  line1, = ax1.plot(tmp['iter'],tmp['loss'],color='blue',label='loss')
  line2, = ax2.plot(tmp['iter'],tmp['hessian'],color='orange',label='hessian')
  ax1.set_xlabel('iteration')
  ax1.set_ylabel('loss')
  ax2.set_ylabel('Hessian norm')

  # A single legend for the lines of both axes
  ax1.legend([line1,line2],['loss','hessian'],loc=0)

  ax1.set_title('sd=%s'%sd)
  fig.tight_layout()
  plt.show()
