<a href="https://colab.research.google.com/github/phrasenmaeher/cka/blob/main/do_nns_learn_the_same%3F.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Do Wide and Deep Neural Networks Learn the Same Things?

Paper is: 

[Do Wide and Deep Networks Learn the Same Things? Uncovering How Neural Network Representations Vary with Width and Depth](https://arxiv.org/abs/2010.15327)

by

Thao Nguyen, Maithra Raghu, and Simon Kornblith

### Preliminary code

In [None]:
import numpy as np
import tqdm

In [None]:
def get_strategy(xla=0, fp16=0, no_cuda=0):
  '''
  Determines the tf.distribute strategy under which the network is trained.

  Adapted from https://github.com/huggingface/transformers/blob/8eb7f26d5d9ce42eb88be6f0150b22a41d76a93d/src/transformers/training_args_tf.py

  Args:
    xla: if truthy, turn on XLA JIT compilation via tf.config.optimizer.set_jit.
    fp16: if truthy, set a global mixed-precision policy: "mixed_float16",
      replaced by "mixed_bfloat16" when a TPU is found.
    no_cuda: if truthy, always use a single-CPU strategy and skip TPU/GPU detection.

  Returns:
    The strategy object: TPUStrategy when a TPU cluster resolves, otherwise
    OneDeviceStrategy on the CPU (no GPU visible) or on the single GPU, or
    MirroredStrategy when several GPUs are visible.
  '''
  print("TensorFlow: setting up strategy")

  if xla:
    tf.config.optimizer.set_jit(True)

  gpus = tf.config.list_physical_devices("GPU")

  # Set to float16 at first; switched to bfloat16 below if a TPU is found
  if fp16:
    policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
    tf.keras.mixed_precision.experimental.set_policy(policy)

  if no_cuda:
    strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  else:
    try:
      tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
      # presumably raised when no TPU can be resolved -- fall back to GPU/CPU
      tpu = None

    if tpu:
      # Set to bfloat16 in case of TPU
      if fp16:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_bfloat16")
        tf.keras.mixed_precision.experimental.set_policy(policy)
      tf.config.experimental_connect_to_cluster(tpu)
      tf.tpu.experimental.initialize_tpu_system(tpu)

      strategy = tf.distribute.experimental.TPUStrategy(tpu)
    elif len(gpus) == 0:
      strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    elif len(gpus) == 1:
      strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    else:
      # More than one GPU (the branches above cover every other count).
      # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
      strategy = tf.distribute.MirroredStrategy()

  print(f"Using strategy: {strategy}")
  return strategy

We first implement the two equations we need from the paper: the unbiased HSIC estimator (equation 2) and CKA built on top of it (equation 1).

#### HSIC

In [None]:
def unbiased_HSIC(K, L):
  '''Computes an unbiased estimator of HSIC. This is equation (2) from the paper.

  HSIC_1(K, L) = 1/(n(n-3)) * [ tr(K~ L~)
                                + (1^T K~ 1)(1^T L~ 1) / ((n-1)(n-2))
                                - 2/(n-2) * 1^T K~ L~ 1 ]
  where K~ and L~ are K and L with their diagonal entries set to zero.

  Args:
    K, L: (n, n) Gram matrices computed over the same n examples.

  Returns:
    The unbiased HSIC estimate as a NumPy float64 scalar.

  Raises:
    ValueError: if K and L are not square matrices of the same size, or if
      n < 4 (the estimator divides by n-2 and n-3).

  The inputs are not modified: the diagonals are zeroed on float64 copies.
  '''
  #work on float64 copies so the caller's matrices are never mutated
  K = np.array(K, dtype=np.float64)
  L = np.array(L, dtype=np.float64)

  n = K.shape[0]
  if K.shape != (n, n) or L.shape != (n, n):
    raise ValueError(f"K and L must be square matrices of the same size, got {K.shape} and {L.shape}")
  if n < 4:
    raise ValueError(f"the unbiased HSIC estimator needs at least 4 examples, got n={n}")

  #fill the diagonal entries with zeros
  np.fill_diagonal(K, 0) #this is now K_tilde
  np.fill_diagonal(L, 0) #this is now L_tilde

  #first part in the square brackets: tr(K~ L~) = sum_ij K~_ij * L~_ji,
  #computed elementwise in O(n^2) instead of via an O(n^3) matrix product
  trace = np.sum(K * L.T)

  #middle part in the square brackets: (1^T K~ 1)(1^T L~ 1) / ((n-1)(n-2))
  middle = K.sum() * L.sum() / ((n - 1) * (n - 2))

  #third part in the square brackets: 2/(n-2) * (1^T K~)(L~ 1),
  #i.e. the column sums of K~ dotted with the row sums of L~
  last = 2 / (n - 2) * np.dot(K.sum(axis=0), L.sum(axis=1))

  #complete equation
  return (trace + middle - last) / (n * (n - 3))


#### CKA

In [None]:
def CKA(X, Y):
  '''Computes the linear CKA of two representation matrices. This is equation (1) from the paper.

  CKA(K, L) = HSIC(K, L) / sqrt(HSIC(K, K) * HSIC(L, L))
  with linear kernels K = X X^T and L = Y Y^T, using the unbiased HSIC estimator.

  Args:
    X: (n, p1) matrix, one row per example.
    Y: (n, p2) matrix over the same n examples; p2 may differ from p1.

  Returns:
    The CKA similarity as a scalar. The result is ill-defined (nan/inf) when
    HSIC(K, K) or HSIC(L, L) is zero, e.g. for a representation that is
    constant across examples.
  '''
  #linear-kernel Gram matrices, computed once and reused by all three HSIC terms.
  #unbiased_HSIC only uses off-diagonal entries, so reusing K and L is safe even
  #if a call zeroes their diagonals in place.
  K = np.dot(X, X.T)
  L = np.dot(Y, Y.T)

  numerator = unbiased_HSIC(K, L)
  hsic_kk = unbiased_HSIC(K, K)
  hsic_ll = unbiased_HSIC(L, L)

  return numerator / np.sqrt(hsic_kk * hsic_ll)


## Creating and Training the networks

#### Imports

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# CIFAR-10 through the Keras datasets API: (images, labels) tuples for the train and test splits
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# scale the pixel values from [0, 255] to [0, 1]
# NOTE(review): Keras' ImageNet ResNet weights are documented to expect
# tf.keras.applications.resnet.preprocess_input rather than [0, 1] scaling -- confirm this is intended
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


#### Setting up training strategy

In [None]:
# Choose TPU / multi-GPU / single-device execution automatically; every option keeps its default
strategy = get_strategy(xla=0, fp16=0, no_cuda=0)

TensorFlow: setting up strategy
INFO:tensorflow:Initializing the TPU system: grpc://10.51.171.130:8470


INFO:tensorflow:Initializing the TPU system: grpc://10.51.171.130:8470


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


Using strategy: <tensorflow.python.distribute.tpu_strategy.TPUStrategy object at 0x7f74efa22d50>


#### Helper functions to create ResNets

In [None]:
def create_resnet50(input_shape=(32, 32, 3), num_classes=10, weights='imagenet'):
  '''Builds a ResNet50 classifier: a pretrained convolutional backbone plus a new softmax head.

  Args:
    input_shape: shape of one input image; defaults to CIFAR-10's (32, 32, 3).
    num_classes: number of units in the softmax head; defaults to CIFAR-10's 10 classes.
    weights: passed through to tf.keras.applications.ResNet50's `weights`
      argument ('imagenet' by default).

  Returns:
    An uncompiled tf.keras.Model mapping images to class probabilities.
  '''
  #backbone without the ImageNet classifier; 'avg' pooling turns the last
  #feature map into one flat feature vector per image
  resnet_base = tf.keras.applications.ResNet50(
    input_shape=input_shape,
    weights=weights,
    pooling='avg',
    include_top=False)

  #new classification head sized for the target dataset
  output = tf.keras.layers.Dense(num_classes, activation="softmax")(resnet_base.output)

  return tf.keras.Model(inputs=[resnet_base.input], outputs=[output])

In [None]:
def create_resnet101(input_shape=(32, 32, 3), num_classes=10, weights='imagenet'):
  '''Builds a ResNet101 classifier: a pretrained convolutional backbone plus a new softmax head.

  Args:
    input_shape: shape of one input image; defaults to CIFAR-10's (32, 32, 3).
    num_classes: number of units in the softmax head; defaults to CIFAR-10's 10 classes.
    weights: passed through to tf.keras.applications.ResNet101's `weights`
      argument ('imagenet' by default).

  Returns:
    An uncompiled tf.keras.Model mapping images to class probabilities.
  '''
  #backbone without the ImageNet classifier; 'avg' pooling turns the last
  #feature map into one flat feature vector per image
  resnet_base = tf.keras.applications.ResNet101(
    input_shape=input_shape,
    weights=weights,
    pooling='avg',
    include_top=False)

  #new classification head sized for the target dataset
  output = tf.keras.layers.Dense(num_classes, activation="softmax")(resnet_base.output)

  return tf.keras.Model(inputs=[resnet_base.input], outputs=[output])

In [None]:
def create_resnet152(input_shape=(32, 32, 3), num_classes=10, weights='imagenet'):
  '''Builds a ResNet152 classifier: a pretrained convolutional backbone plus a new softmax head.

  Args:
    input_shape: shape of one input image; defaults to CIFAR-10's (32, 32, 3).
    num_classes: number of units in the softmax head; defaults to CIFAR-10's 10 classes.
    weights: passed through to tf.keras.applications.ResNet152's `weights`
      argument ('imagenet' by default).

  Returns:
    An uncompiled tf.keras.Model mapping images to class probabilities.
  '''
  #backbone without the ImageNet classifier; 'avg' pooling turns the last
  #feature map into one flat feature vector per image
  resnet_base = tf.keras.applications.ResNet152(
    input_shape=input_shape,
    weights=weights,
    pooling='avg',
    include_top=False)

  #new classification head sized for the target dataset
  output = tf.keras.layers.Dense(num_classes, activation="softmax")(resnet_base.output)

  return tf.keras.Model(inputs=[resnet_base.input], outputs=[output])

### Train ResNets

Train a ResNet50

In [None]:
with strategy.scope():
  # Build and compile inside the strategy scope so the model's variables are
  # created under the distribution strategy chosen by get_strategy().
  resnet50 = create_resnet50()
  # The head already applies softmax, so the loss works on probabilities
  # (SparseCategoricalCrossentropy's default), with integer class labels.
  resnet50.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
  )

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# 10 epochs over the full CIFAR-10 training split, 256 images per step
resnet50.fit(x=x_train, y=y_train, batch_size=256, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f747e268450>

Train a ResNet101

In [None]:
with strategy.scope():
  # Build and compile inside the strategy scope so the model's variables are
  # created under the distribution strategy chosen by get_strategy().
  resnet101 = create_resnet101()
  # The head already applies softmax, so the loss works on probabilities
  # (SparseCategoricalCrossentropy's default), with integer class labels.
  resnet101.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
  )



Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# 10 epochs over the full CIFAR-10 training split, 256 images per step
resnet101.fit(x=x_train, y=y_train, batch_size=256, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f74791269d0>

Train a ResNet152


In [None]:
with strategy.scope():
  # Build and compile inside the strategy scope so the model's variables are
  # created under the distribution strategy chosen by get_strategy().
  resnet152 = create_resnet152()
  # The head already applies softmax, so the loss works on probabilities
  # (SparseCategoricalCrossentropy's default), with integer class labels.
  resnet152.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
  )

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# 10 epochs over the full CIFAR-10 training split, 256 images per step
resnet152.fit(x=x_train, y=y_train, batch_size=256, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f7471fd3b50>

## Activation comparison

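Compute the linear CKA score between two activation tensors. Each tensor is first flattened ("unrolled") into a matrix with one row per example, e.g. `(n, h, w, c)` becomes `(n, h*w*c)`, so that layers with different shapes can be compared.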

In [None]:
def calculate_CKA_for_two_matrices(activationA, activationB):
  '''Takes two activations A and B and computes the linear CKA to measure their similarity.

  Each activation is first flattened ('unrolled') to a 2-D matrix with one row
  per example, e.g. a (n, h, w, c) conv output becomes (n, h*w*c). This lets
  layers with different shapes be compared.

  Args:
    activationA: array-like with the example axis first.
    activationB: array-like with the example axis first, over the same n examples.

  Returns:
    The linear CKA score as returned by CKA().

  Raises:
    ValueError: if the two activations have a different number of examples.
  '''
  activationA = np.asarray(activationA)
  activationB = np.asarray(activationB)
  if activationA.shape[0] != activationB.shape[0]:
    raise ValueError(
      f"activations must cover the same examples, got {activationA.shape[0]} and {activationB.shape[0]} rows")

  #unfold the activations, that is make a (n, h*w*c) representation;
  #-1 lets NumPy infer the feature size, which also turns 1-D activations into (n, 1)
  flatA = activationA.reshape(activationA.shape[0], -1)
  flatB = activationB.reshape(activationB.shape[0], -1)

  #calculate the CKA score
  return CKA(flatA, flatB)

Build a function that maps an input batch to the outputs of every intermediate layer of a model (every layer after the first one).

In [None]:
def get_all_layer_outputs_fn(model):
  '''Builds a Keras backend function mapping a model input to every intermediate layer output.

  The returned callable takes the input batch and returns one output per layer
  in model.layers[1:], in layer order. The first layer is skipped; presumably
  it is the model's InputLayer, so its output would just echo the input.

  NOTE(review): tf.keras.backend.function is a legacy API -- confirm it is still
  available in the TensorFlow version in use.
  '''
  first_layer = model.layers[0]
  per_layer_outputs = [layer.output for layer in model.layers[1:]]
  return tf.keras.backend.function([first_layer.input], per_layer_outputs)


In [None]:
def compare_activations(modelA, modelB, data_batch):
  '''
  Calculate a pairwise comparison of hidden representations and return a matrix.

  Feeds the same data_batch through both models and collects the output of
  every layer after the first one. It then computes linear CKA for every
  (layer of modelA, layer of modelB) pair.

  Args:
    modelA, modelB: Keras models to compare.
    data_batch: input batch fed to both models.

  Returns:
    np.ndarray of shape (num_layers_A, num_layers_B). Entry [i, j] is the CKA
    between the i-th collected layer output of modelA and the j-th of modelB.
  '''
  from tqdm.auto import tqdm as tqdm_auto

  #get the output of every intermediate layer, for modelA and modelB
  intermediate_outputs_A = get_all_layer_outputs_fn(modelA)(data_batch)
  intermediate_outputs_B = get_all_layer_outputs_fn(modelB)(data_batch)

  result_array = np.zeros(shape=(len(intermediate_outputs_A), len(intermediate_outputs_B)))

  #one progress bar over all layer pairs instead of a new nested bar for every row
  with tqdm_auto(total=result_array.size, desc="CKA") as progress:
    for i, outputA in enumerate(intermediate_outputs_A):
      for j, outputB in enumerate(intermediate_outputs_B):
        result_array[i, j] = calculate_CKA_for_two_matrices(outputA, outputB)
        progress.update(1)

  return result_array

In [None]:
# Layer-by-layer CKA between ResNet50 (rows) and ResNet101 (columns) on the first 256 training images
sim = compare_activations(modelA=resnet50, modelB=resnet101, data_batch=x_train[:256])

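## Comparing precomputed activations

The cells below load previously saved per-layer activations from Google Drive (pickle files). They then plot pairwise CKA heatmaps between ResNet50, ResNet101 and ResNet152, including each network compared with itself.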

In [None]:
def compare_activations2(intermediate_outputs_A, intermediate_outputs_B):
  '''
  Pairwise CKA between two precomputed sequences of layer activations.

  Unlike compare_activations, nothing is run through a model here: the inputs
  are already-collected activations, one per layer, each with the example axis
  first and covering the same examples in both sequences.

  Args:
    intermediate_outputs_A: sequence of activations for the first network.
    intermediate_outputs_B: sequence of activations for the second network.

  Returns:
    np.ndarray of shape (len(intermediate_outputs_A), len(intermediate_outputs_B)).
    Entry [i, j] is the CKA between activation i of A and activation j of B.
  '''
  from tqdm.auto import tqdm as tqdm_auto

  result_array = np.zeros(shape=(len(intermediate_outputs_A), len(intermediate_outputs_B)))

  #one progress bar over all layer pairs instead of a new nested bar for every row
  with tqdm_auto(total=result_array.size, desc="CKA") as progress:
    for i, outputA in enumerate(intermediate_outputs_A):
      for j, outputB in enumerate(intermediate_outputs_B):
        result_array[i, j] = calculate_CKA_for_two_matrices(outputA, outputB)
        progress.update(1)

  return result_array


In [None]:
import pickle

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
# From here on `resnet50` is the unpickled object (fed to compare_activations2 below), not the Keras model.
with open("/content/drive/MyDrive/activ_comparison/resnet50", mode="rb") as rp:
  resnet50 = pickle.load(rp)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
# From here on `resnet101` is the unpickled object (fed to compare_activations2 below), not the Keras model.
with open("/content/drive/MyDrive/activ_comparison/resnet101", mode="rb") as rp:
  resnet101 = pickle.load(rp)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
# From here on `resnet152` is the unpickled object (fed to compare_activations2 below), not the Keras model.
with open("/content/drive/MyDrive/activ_comparison/resnet152", mode="rb") as rp:
  resnet152 = pickle.load(rp)

In [None]:
# Layer-by-layer linear CKA: ResNet50 (rows / y-axis) vs ResNet101 (columns / x-axis)
sim = compare_activations2(resnet50, resnet101)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet50 vs ResNet101", xlabel="ResNet101 layer", ylabel="ResNet50 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r50_r101.png", dpi=400)

In [None]:
# Layer-by-layer linear CKA: ResNet50 (rows / y-axis) vs ResNet152 (columns / x-axis)
sim = compare_activations2(resnet50, resnet152)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet50 vs ResNet152", xlabel="ResNet152 layer", ylabel="ResNet50 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r50_r152.png", dpi=400)

In [None]:
# Layer-by-layer linear CKA: ResNet101 (rows / y-axis) vs ResNet152 (columns / x-axis)
sim = compare_activations2(resnet101, resnet152)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet101 vs ResNet152", xlabel="ResNet152 layer", ylabel="ResNet101 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r101_r152.png", dpi=400)

In [None]:
# Layer-by-layer linear CKA of ResNet50 with itself
sim = compare_activations2(resnet50, resnet50)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet50 vs ResNet50", xlabel="ResNet50 layer", ylabel="ResNet50 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r50_r50.png", dpi=400)

In [None]:
# Layer-by-layer linear CKA of ResNet101 with itself
sim = compare_activations2(resnet101, resnet101)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet101 vs ResNet101", xlabel="ResNet101 layer", ylabel="ResNet101 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r101_r101.png", dpi=400)

In [None]:
# Layer-by-layer linear CKA of ResNet152 with itself
sim = compare_activations2(resnet152, resnet152)
fig, ax = plt.subplots(figsize=(30, 15), dpi=200)
image = ax.imshow(sim, cmap='magma', vmin=0.0, vmax=1.0)
ax.invert_yaxis()  # put the first layers in the bottom-left corner
ax.set(title="Linear CKA: ResNet152 vs ResNet152", xlabel="ResNet152 layer", ylabel="ResNet152 layer")
fig.colorbar(image, ax=ax, label="CKA")
fig.savefig("/content/drive/MyDrive/activ_comparison/r152_r152.png", dpi=400)