<a href="https://colab.research.google.com/github/pnorridge/think-global-act-local/blob/master/MNIST_%26_SNR_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is a supplement to the paper *Think Global, Act Local: The relationship between DNN generalisation and node-level information preservation.*

Abstract: We argue that the (global) generalisation performance of a DNN is related to the information preservation and Signal-to-Noise Ratio of individual nodes. Further, some weight combinations generate better a SNR than others. We demonstrate this by deriving figures-of-merit that can be applied to weight sets and give examples of the correlation between these figures-of-merit with DNN generalisation.

This notebook gives an example of the process applied to MNIST classification.

The underlying reasoning for considering node-level SNR and SNR-optimising weights can be summarised as:

* A DNN will perform best when the available information is used to maximum extent.

* Optimal use of information by the entire network depends on maximising the information preservation of individual nodes.

* Linsker (R._Linsker 1988) has shown the relationship between the SNR of a node and the maximisation of the information preservation from inputs to output (or, equivalently, the information rate out at the output of the node.) So, if we want to optimise the information flow in the network, we should pay attention to the SNR of the individual nodes.

* It is possible to quantify how well a given set of weights optimises the SNR within the context of a given network and training set. 

# Admin

In [0]:
# Admin

import tensorflow as tf
tf.enable_eager_execution()

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()


In [0]:
# Load in the test data & configure an iterator
mnist = tf.keras.datasets.mnist
(train_images, train_labels),(test_images, test_labels) = mnist.load_data()
train_images, test_images = tf.cast(train_images, tf.float32) / 255.0, tf.cast(test_images, tf.float32) / 255.0
train_labels, test_labels = tf.cast(train_labels, tf.int32), tf.cast(test_labels, tf.int32)

dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
batched_data = dataset.shuffle(buffer_size=60000).batch(50).repeat()
data_iterator = batched_data.make_one_shot_iterator()

# Set up a zero vector -- used as reference points when doing correlations
x_zero = tf.zeros([10,28,28])
corr_test_images = tf.concat([tf.cast(test_images,tf.float32),x_zero], axis=0)

In [0]:
# Essential model-independent helper functions
def weight_variable(shape):
  initial = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
  return initial

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)
    
def measure_accuracy(labels, logits):
  correct_prediction = tf.equal(tf.cast(labels, tf.int64), tf.argmax(logits, axis=1))
  return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

def cross_entropy(x, y):
  return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x))


def dot_layer(f, wts):  
  def layer(x,p): 
    return tf.matmul(f(x,p),wts)
  
  return layer


def relu_layer(f, bs):
  def layer(x,p):
    return tf.nn.dropout(tf.nn.relu(f(x,p)+bs),p)
  
  return layer


def masked_mean(measure, weight = None):
  if weight is None:
    weight = tf.ones(measure.shape)
  
  mean = tf.reduce_mean(tf.boolean_mask(measure,tf.logical_not(tf.is_nan(measure))))
  return mean

# Building blocks for figures-of-merit

We start by defining a modified covariance function. This calculates the covariance between a node input and the output -- and scales by the non-zero rate of the input.

\begin{equation}
 c_{i}= \frac{\text{cov} \left( x_i,\sum w_{ij} x_j\right)}{a_i} = \frac{\text{cov} \left( x_i,r_i\right)}{a_i} 
\end{equation}

$ a_i = p(x_i > 0) $


In [0]:
def relu_covariance_batch(x,r):
  # This calculates the covariance, including the scaling factor appropriate for ReLU-like inputs.
  # Implemented with batching due to memory issues with large inputs.

  batch_sz = 50
  
  leng = tf.cast(x.shape[0],tf.float32)

  m = int(leng/batch_sz-1)

  # Normalise the inputs. This is a quick and dirty way to get the noise variances to approximately the same magnitude.
  max_count = tf.cast(x.shape[0], tf.float32)
  x = x/(tf.reduce_max(x, axis = 0)+0.001)
  
  xx = tf.cast(tf.expand_dims(x,-1),tf.float32)
  rr = tf.expand_dims(r,1)
  
  # Shift inputs to be zero mean
  xxm = tf.reduce_sum(xx[batch_sz*m:], axis = 0)
  rrm = tf.reduce_sum(rr[batch_sz*m:], axis = 0)
  for k in range(0,m):
    xxm = xxm + tf.reduce_sum(xx[batch_sz*k:batch_sz*(k+1)], axis = 0)
    rrm = rrm + tf.reduce_sum(rr[batch_sz*k:batch_sz*(k+1)], axis = 0)
  
  xx = xx - xxm/leng #(x_count+0.0001) # lazily ensure no 1/0
  rr = rr - rrm/tf.cast(rr.shape[0],tf.float32)
  
  # Calculate the variance
  xr = tf.reduce_sum(xx[batch_sz*m:,:,:]*rr[batch_sz*m:,:,:], axis = 0)
  for k in range(0,m):
    xr = xr + tf.reduce_sum(xx[batch_sz*k:batch_sz*(k+1),:,:]*rr[batch_sz*k:batch_sz*(k+1),:,:], axis = 0)

  # Scaling for ReLU inputs 
  x_count = tf.reduce_sum(tf.cast(xx>0, tf.float32), axis = 0) 
  xr = xr/(x_count+0.0001) # lazily ensure no 1/0

  return xr


To derive the SNR Optimiality expression, we start with the usual (signal processing) definition of SNR applied to node $i$ in layer $m$

\begin{equation}
SNR_i^{(m)} =\frac{\text{var} \left(\sum_{j}w_{ij}s_{j} \right)}{ \sum_{j}\text{var} \left( w_{ij}n_{j} \right) } 
\end{equation}

where $s_j$ are the signal and $n_j$ are the noise components. At this point we leave open how to partition inputs into these components. In all cases, variance and covariances are calculated over a batch. We look for weights that maximise the expression.  

\begin{equation}
w_{ij}= k_i .  \frac{\text{cov} \left( s_j, \sum w_{ij} s_i \right) }{\text{var} \left( n_j \right)}  
\end{equation}

Where  $ k_i $  is a constant independent of  \( j \)  and we have implicitly assumed that  $ \text{var} \left( n_{j} \right)\neq 0 $ .\par

To allow us to use this, we make two pragmatic assumptions. 
The first is that the after training, the signal dominates the node output, so that 
\begin{equation}
\text{cov} \left( s_i, \sum w_{ij} s_j \right) \approx \text{cov} \left( x_j, \sum w_{ij} x_i \right)
\end{equation}

The second is that the noise has identical variance for each non-zero input sample and zero variance otherwise. With these assumptions, we make the approximation
\begin{equation}
 \text{var} \left( n_{j} \right) \approx a_j ~ \text{var}(n)
\end{equation}

where $\text{var}(n)$ is the common noise variance of all samples arriving at node $i$ and $a_i$ is the rate of activations of node $j$ of layer $(m-1)$. 

Using these assumptions, we update the expression for optimal weights to
\begin{equation}
w_{ij} = k_i' .  \frac{\text{cov} \left( x_i, \sum w_{ij} x_j \right) }{ a_j }  
\end{equation}


We do not expect that nodes will generally meet this condition, but we would like to assess how close a given weight configuration is to 'optimal'. We can measure this by treating $w_{ij}$ and $\frac{\text{cov} \left( x_i, \sum w_{ij} x_j \right) }{\left( a_j \right)}$ as vectors and calculating the inner product between them\footnote{At this point, we note that there is a close relationship between this condition and PCA. When $a_j = 1 \ \forall j$, the weight will be optimal if it is an eigenvector of the covariance matrix.}.

\begin{equation}
S_i =\frac{ \sum _{j}w_{ij}c_{j}}{ \vert \vec{w_{i}} \vert  \vert \vec{c} \vert } 
\end{equation}

where $\vec{c}$ is as defined above.

In [0]:

def SNROptimality_batch(x, r, W):
  
    xR = relu_covariance_batch(x,r)
    
    nm = (tf.norm(xR,axis=0)+0.0001)
    c_hat = xR/nm

    
    count = tf.reduce_sum(tf.cast(c_hat*W<0, tf.float32))
    numer = tf.reduce_sum(c_hat*W, axis = 0)
  
    norm_wtv = tf.norm(W, axis=0)

    denom = norm_wtv

    return (numer/denom, count)
  


To measure SNR improvement, we compare the SNR Optimiality for a given weight vector with the case where we simply selected the most representative input and passed that through transparently. That is, we calculating $S_i$ using only the input that is best correlated with the output of the node. We fix $c_j$ and then use $w_{ij}'$ defined by
\begin{equation}
    w_{ij}' =  \begin{cases}
                1 \ \ \  \text{if }  j=\text{argmax}_k \left( c_k \right)   \\
               0 \ \ \  \text{otherwise}
            \end{cases}
\end{equation}

For this choice,
\begin{equation} 
S_i' = \frac{ \sum _{j}^{}w_{ij}'c_{j}}{ \vert \vec{w'_i} \vert  \vert \vec{c} \vert }  =  \frac{\max_{i} \left( c_{i} \right) }{ \vert \vec{c} \vert } 
\end{equation}

The ratio of the $S_i$ and $S_i'$ becomes
\begin{equation}
F=\frac{ \sum _{j}^{}w_{ij}c_{j}}{ \vert \vec{w_i} \vert \max_{i} \left( c_{i} \right) } 
\end{equation}


In [0]:
def SignalFactor_batch(x, r, W):
  
    xR = relu_covariance_batch(x,r)
    
    nm = (tf.reduce_max(tf.abs(xR),axis=0,keepdims=True))
    c_hat = xR/nm

    
    count = tf.reduce_sum(tf.cast(c_hat*W<0, tf.float32))
    numer = tf.reduce_sum(c_hat*W, axis = 0)
  
    norm_wtv = tf.norm(W, axis=0)

    denom = norm_wtv

    return (numer/denom, count)
  

The following will be used later to aid with visualisation of the SNR optimality

In [0]:
def Optimality_Plot(x, r, W, node_list):
  
    xR = relu_covariance_batch(x,r)
        
    W1 = W[:,node_list]
 
    c_hat = xR/(tf.norm(xR, axis=0))

    norm_wtv = tf.norm(W1, axis=0)

    normed_wt = W1/norm_wtv

    fig,ax = plt.subplots(figsize=(5,5))

    ax.plot(normed_wt,c_hat[:,node_list],'.')
    ax.fill_between([0, tf.reduce_max(normed_wt).numpy()], 0,tf.reduce_max(c_hat[:,node_list]).numpy(),alpha=0.2, color='#1F98D0')  
    ax.fill_between([tf.reduce_min(normed_wt).numpy(), 0], tf.reduce_min(c_hat[:,node_list]).numpy(), 0, alpha=0.2, color='#1F98D0')  

    plt.show()

<h1>Model definition </h1>

Define our base class. The costruction is realtively low-level to allow easy access to all the information needed for the figures-of-merit.

In [0]:

class trainable_function:

  def __init__(self, withWeightsFrom = None):
    self.W = []
    self.b = []
    # Set up the weights
    if withWeightsFrom == None:
      self.W.append(weight_variable([784, 1024]))
      self.b.append(bias_variable([1024]))
    
      self.W.append(weight_variable([1024, 1000]))
      self.b.append(bias_variable([1000]))

      self.W.append(weight_variable([1000, 10]))
      self.b.append(bias_variable([10]))
      
    else:
      for k in range(len(withWeightsFrom.W)):
        self.W.append(tf.Variable(withWeightsFrom.W[k]))
        self.b.append(tf.Variable(withWeightsFrom.b[k]))

    self.variables = [self.W[2], self.b[2], self.W[1], self.b[1], self.W[0], self.b[0]]
    
    # Build the network with lambda functions.
    self.layer_dot = [lambda x, p: x]
    self.layer_out = [lambda x, p: tf.reshape(x,[-1,784])]
    
    for kk in range(1,len(self.W)):
      self.layer_dot.append(dot_layer(self.layer_out[kk-1], self.W[kk-1])) 
      self.layer_out.append(relu_layer(self.layer_dot[kk], self.b[kk-1]))

    self.layer_dot.append(dot_layer(self.layer_out[-1], self.W[-1])) 

    self.optimiser = tf.train.GradientDescentOptimizer(0.1)


  def classify(self, x, p = 1.):
    return self.layer_dot[-1](x,p)


  def train(self, x, y):
    null

  # Calculate F for one layer
  def SignalFactor(self, x, layer):
    x_int = self.layer_out[layer-1](x,1.)
    r = tf.matmul(x_int, self.W[layer-1]) 

    return SignalFactor_batch(x_int, r, self.W[layer-1])


  # Calculate mean(F) for each layer
  def SF_summary(self, x):
    res = []
    for k in range(len(self.layer_out)):
      (measure, count)  = self.SignalFactor(corr_test_images, k+1)
      res.append(masked_mean(measure))
    return res


  # Calculate S for one layer
  def SNROptimality(self, x, layer):
    x_int = self.layer_out[layer-1](x,1.)
    r = tf.matmul(x_int, self.W[layer-1]) 

    return SNROptimality_batch(x_int, r, self.W[layer-1])
  
  # Calculate mean(S) for each layer
  def SNR_summary(self, x):
    res = []
    for k in range(len(self.layer_out)):
      (measure, count)  = self.SNROptimality(corr_test_images, k+1)
      res.append(masked_mean(measure))
    return res


  def Optimality_plot(self, x, layer, nodes):
    x_int = self.layer_out[layer-1](x,1.)
    r = tf.matmul(x_int, self.W[layer-1]) 
    Optimality_Plot(x_int, r, self.W[layer-1], nodes)



Sub-classes representing different training methods/regularisation.

In [0]:


# Basic SGD with no regulsatisation
class trainable_function_with_noreg(trainable_function):

  def calc_grad(self, inputs, targets):
    with tf.GradientTape() as tape:
      loss_value = cross_entropy(self.classify(inputs, p = 1.), targets)
    return tape.gradient(loss_value, self.variables)

  def train(self, x,y):
    grads = self.calc_grad(x, y)
    self.optimiser.apply_gradients(zip(grads, self.variables))

# Basic SGD with dropout
class trainable_function_with_dropout(trainable_function):

  def calc_grad(self, inputs, targets):
    with tf.GradientTape() as tape:
      loss_value = cross_entropy(self.classify(inputs, p = 0.5), targets)
    return tape.gradient(loss_value, self.variables)

  def train(self, x,y):
    grads = self.calc_grad(x, y)
    self.optimiser.apply_gradients(zip(grads, self.variables))
    

# Basic SGD with L2 regulsatisation
class trainable_function_with_L2(trainable_function):

  def calc_grad(self, inputs, targets):
    with tf.GradientTape() as tape:
      loss_value = cross_entropy(self.classify(inputs, p = 1.), targets) \
                                 +0.0001*tf.nn.l2_loss(self.W[0])\
                                 +0.0001*tf.nn.l2_loss(self.W[1])\
                                 +0.0001*tf.nn.l2_loss(self.W[2])
                                 
    return tape.gradient(loss_value, self.variables)

  def train(self, x,y):
    grads = self.calc_grad(x, y)
    self.optimiser.apply_gradients(zip(grads, self.variables))



<h1>Training & evaluation</h1>

Set up three models for comparison. We use identical starting weights and training data, to ensure that there is a comparison between training methods, not simply the initialisation and data.


In [0]:
control = trainable_function_with_noreg()
model_n = trainable_function_with_noreg(control)
model_d = trainable_function_with_dropout(control)
model_l2 = trainable_function_with_L2(control)

model_list = [model_n, model_d, model_l2]

test_acc = {lst: [] for lst in model_list} 
SNR = {lst: [] for lst in model_list} 
SF = {lst: [] for lst in model_list} 


Training with regular calculations of the figures-of-merit. A good time for a coffee.




In [0]:
# Main loop
for i in range(100001):

  # Get training data
  (x,y) = data_iterator.get_next()

  # Loop over models
  for model in model_list:

      # Instrumentation

      if i % 200 == 0:
        test_accuracy  = measure_accuracy(labels = test_labels, logits = model.classify(test_images))
        test_acc[model].append(test_accuracy)


      if i % 2000 == 0:
        SF[model].append(model.SF_summary(corr_test_images))
        SNR[model].append(model.SNR_summary(corr_test_images))
        

      if i % 20000 == 0:
        print('\n Step %d' % i)
        print(type(model).__name__)
        print('Test accuracy %g' % test_accuracy)
        print('L1: %g' % SNR[model][-1][0])
        print('L2: %g' % SNR[model][-1][1])
        print('L3: %g' % SNR[model][-1][2])


      # Actual training          
      model.train(x, y)

  if i % 20000 == 0:
    print('------\n')





 Step 0
trainable_function_with_noreg
Test accuracy 0.1003
L1: 0.182793
L2: 0.157043
L3: 0.186966

 Step 0
trainable_function_with_dropout
Test accuracy 0.1003
L1: 0.182793
L2: 0.157043
L3: 0.186966

 Step 0
trainable_function_with_L2
Test accuracy 0.1003
L1: 0.182793
L2: 0.157043
L3: 0.186966
------



# Test accuracy for each model 

In [0]:
for m in model_list:
  plt.plot(test_acc[m])
plt.ylim([0.97,0.99])
plt.ylabel('test accuracy')
plt.xlabel('iterations (200x)')

# SNR optimality

Note that the SNR optimality is correlated with performance after convergence. Interestingly, for dropout the SNR optimality indicates the enhanced performance before this is seen in the test accuracy results.

In [0]:
fig, a = plt.subplots(1,3,figsize=(15, 5))
for m in model_list:
  a[0].plot([SNR[m][a][0] for a in range(len(SNR[m]))])
  a[1].plot([SNR[m][a][1] for a in range(len(SNR[m]))])
  a[2].plot([SNR[m][a][2] for a in range(len(SNR[m]))])

# Signal Figure

In [0]:
fig, a = plt.subplots(1,3,figsize=(15, 5))
for m in model_list:
  a[0].plot([SF[m][a][0] for a in range(len(SF[m]))])
  a[1].plot([SF[m][a][1] for a in range(len(SF[m]))])
  a[2].plot([SF[m][a][2] for a in range(len(SF[m]))])

# Scatter plots of $W$ vs. $c$

Heere, we plot $W$ against $c$ for one node. This gives a more intuitive & visual sense of what it means to have good SNR optimality. The points become increasingly close to the line $W=c$.

In [0]:
for m in control+model_list:
  ind = tf.argmax(m.SNROptimality(corr_test_images,3)[0])
  m.Optimality_plot(corr_test_images,3,ind)

In [0]:
m.SNROptimality(corr_test_images,2)[0]

In [0]:
ind