In [None]:
import collections
import logging
import os
import sys
import time
from os import listdir
from os.path import isfile, join

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.signal
import scipy.stats
import tensorflow as tf
import torch



```
# This is formatted as code
```

# Initialize the experiment: get the list of saved model checkpoints and other parameters

In [None]:
# NOTE(review): this cell is lifted from a learner method. `self`, `split` and `num_eval_trials`
# are not defined anywhere in this notebook and must be provided before it can run. Both
# `self.learn_config` and `self.learner_config` are read below -- confirm both attributes exist.
num_analysis_trials = self.learn_config.num_analysis_episodes
analyze_every = self.learn_config.analyze_every

logging.info('Performing analysis in representation space w.r.t the %s split using %d episodes...',
              split, num_analysis_trials)
logging.info('Evaluating generalization on the %s split using %d episodes...',
              split, num_eval_trials)

# Get directory where model checkpoints are saved
eval_checkpoints_dir = self.learner_config.checkpoint_for_eval.split('model_')[0]

# Checkpoints `model_{iteration}.ckpt` are enumerated through their `.ckpt.index` files.
# The iteration is parsed by slicing off the suffix and splitting on the prefix.
# `str.strip(chars)`, used previously, removes any of the given *characters* from both ends
# rather than a prefix/suffix, and only worked by coincidence for these names.
CKPT_INDEX_SUFFIX = '.ckpt.index'
saved_iterations = sorted(
    int(f[:-len(CKPT_INDEX_SUFFIX)].split('model_')[-1])
    for f in listdir(eval_checkpoints_dir)
    if 'model_' in f and f.endswith(CKPT_INDEX_SUFFIX) and isfile(join(eval_checkpoints_dir, f)))
# Numeric (not alphabetical) order: model_500, model_1000, model_1500
# rather than model_1000, model_1500, model_500.
eval_checkpoints = ['model_' + str(iteration) + '.ckpt' for iteration in saved_iterations]

# Analyze every `analyze_every`-th checkpoint, starting from the first one
eval_checkpoints = eval_checkpoints[::analyze_every]


# Check whether the experiment is already completed; otherwise resume or initialize it

In [None]:
# Check if the experiment is actually already completed
if os.path.isfile(join(self.summary_dir, 'completed.npy')):
  # Nothing left to do: sys.exit() raises SystemExit, which stops this cell.
  sys.exit()

# Check if experiment has already started
elif os.path.isfile(join(self.summary_dir, 'activation_moments_trajectory.npy')):
  # allow_pickle=True is needed because the file holds a pickled dict; only load trusted files.
  exp_dict = np.load(join(self.summary_dir, 'activation_moments_trajectory.npy'), allow_pickle=True)
  exp_dict = exp_dict.item()

  m1_models_list = exp_dict['m1_models_list']
  m2_models_list = exp_dict['m2_models_list']
  m3_models_list = exp_dict['m3_models_list']
  m4_models_list = exp_dict['m4_models_list']
  analyzed_models_list = exp_dict['analyzed_models_list']
  episode_list = exp_dict['episode_list']

  # Resuming the experiment after the last analyzed checkpoint.
  # Iteration numbers are compared instead of looking the last checkpoint up with
  # `eval_checkpoints.index(...)`. That lookup raised ValueError whenever the checkpoint was not
  # in the current (subsampled) list, e.g. after changing `analyze_every`.
  # Names have the form 'model_{iteration}.ckpt'.
  if analyzed_models_list:
    last_analyzed_iteration = int(analyzed_models_list[-1][:-len('.ckpt')].split('model_')[-1])
    eval_checkpoints = [
        checkpoint for checkpoint in eval_checkpoints
        if int(checkpoint[:-len('.ckpt')].split('model_')[-1]) > last_analyzed_iteration]

# If experiment has not started yet
else:
  # Initializing the experiment data lists
  m1_models_list = []
  m2_models_list = []
  m3_models_list = []
  m4_models_list = []
  analyzed_models_list = []

  # Getting a static list of the support sets of all episodes in the test split.
  # This way, the same episodes will be used for all checkpoints of the model, to monitor the evolution of the metric.
  episodes = self.learners['test'].data.train_images
  episode_list = []
  for analysis_trial_num in range(num_analysis_trials):
    episode_tensors = self.sess.run(episodes)
    episode_list.append(episode_tensors)

# Example: exposing the intermediate activations of a neural network

In [None]:
def _four_layer_convnet(inputs,
                        scope,
                        reuse=tf.AUTO_REUSE,
                        params=None,
                        moments=None,
                        depth_multiplier=1.0,
                        backprop_through_moments=True,
                        use_bounded_activation=False,
                        keep_spatial_dims=False):
  """A four-layer-convnet architecture that also returns its per-block activations.

  Each of the four blocks applies `conv_bn` (3x3 kernel, int(64 * depth_multiplier) output
  channels, stride 1), then a ReLU (ReLU6 if `use_bounded_activation`), then 2x2 max-pooling
  with stride 2. The output of every block (after pooling) is collected in
  `intermediate_activations`, so that representation statistics can be computed per layer.

  Args:
    inputs: input tensor; wrapped in tf.stop_gradient, so no gradient flows back into it.
    scope: variable scope that encloses the per-block scopes 'layer_0' .. 'layer_3'.
    reuse: passed to tf.variable_scope for the outer and the per-block scopes.
    params: forwarded to `conv_bn` (presumably explicit parameter values overriding the
      scope's variables -- TODO confirm against conv_bn).
    moments: forwarded to `conv_bn` (presumably batch-norm moments -- TODO confirm).
    depth_multiplier: scales the 64 output channels of every block.
    backprop_through_moments: forwarded to `conv_bn`.
    use_bounded_activation: use tf.nn.relu6 instead of tf.nn.relu.
    keep_spatial_dims: if False, 'embeddings' is flattened. 'intermediate_activations' are
      never flattened.

  Returns:
    A dict with:
      'embeddings': output of the last block (flattened unless keep_spatial_dims).
      'params': OrderedDict of the parameters reported by the conv_bn calls, in call order.
      'moments': OrderedDict of the moments reported by the conv_bn calls, in call order.
      'intermediate_activations': list with the output of each of the 4 blocks, after pooling.
  """
  layer = tf.stop_gradient(inputs)
  model_params_keys, model_params_vars = [], []
  moments_keys, moments_vars = [], []
  intermediate_activations = [] # Output of each block, exposed for representation analysis.

  with tf.variable_scope(scope, reuse=reuse):
    for i in range(4):
      with tf.variable_scope('layer_{}'.format(i), reuse=reuse):
        depth = int(64 * depth_multiplier)
        layer, conv_bn_params, conv_bn_moments = conv_bn(
            layer, [3, 3],
            depth,
            stride=1,
            params=params,
            moments=moments,
            backprop_through_moments=backprop_through_moments)
        # Accumulate keys/values separately; they are zipped into OrderedDicts after the loop.
        model_params_keys.extend(conv_bn_params.keys())
        model_params_vars.extend(conv_bn_params.values())
        moments_keys.extend(conv_bn_moments.keys())
        moments_vars.extend(conv_bn_moments.values())

      if use_bounded_activation:
        layer = tf.nn.relu6(layer)
      else:
        layer = tf.nn.relu(layer)
      layer = tf.layers.max_pooling2d(layer, [2, 2], 2)
      intermediate_activations.append(layer) # Block output after pooling (never flattened).
      logging.info('Output of block %d: %s', i, layer.shape)

    model_params = collections.OrderedDict(
        zip(model_params_keys, model_params_vars))
    # Note: this rebinds the `moments` argument; it is not used after the loop.
    moments = collections.OrderedDict(zip(moments_keys, moments_vars))
    if not keep_spatial_dims:
      layer = tf.layers.flatten(layer)
    return_dict = {
        'embeddings': layer,
        'params': model_params,
        'moments': moments,
        'intermediate_activations': intermediate_activations # One tensor per block.
    }

    return return_dict

# Functions to compute the neural activation trajectory

In [None]:
def compute_m1_hat(h_vectors_):
  r"""
  Computes the first raw moment (sample mean) of first-order moments.
  First-order moments are the (sample) means of the individual activation features.
  For a matrix H \in R^{N \times d} of N activation vectors h, where h \in R^d :
  m1_hat = \frac{1}{d} \sum_{j=1}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j}
  :param h_vectors_: the activation vectors; every axis after the first is flattened into the
      d feature dimensions.
  :return: scalar tensor m1_hat.
  """
  H = tf.layers.flatten(h_vectors_)

  # First-order moment of the activation vectors : mean vector (mean over the N examples).
  # reduce_mean uses the runtime sizes, so this also works when N is not known statically
  # (converting an unknown static dimension to a float fails).
  m1 = tf.math.reduce_mean(H, axis=0)

  # m1_hat : first aggregated moment - mean across feature dimensions of the mean activation vector
  m1_hat = tf.math.reduce_mean(m1, axis=0)

  return m1_hat

def compute_m2_hat(h_vectors_):
  r"""
  Computes the second raw moment (sample variance around zero) of first-order moments.
  First-order moments are the (sample) means of the individual activation features.
  For a matrix H \in R^{N \times d} of N activation vectors h, where h \in R^d :
  m2_hat = \frac{1}{d} \sum_{j=1}^{d} ( \frac{1}{N} \sum_{i=1}^{N} H_{i,j} )^2
  :param h_vectors_: the activation vectors; every axis after the first is flattened into the
      d feature dimensions.
  :return: scalar tensor m2_hat.
  """
  H = tf.layers.flatten(h_vectors_)

  # First-order moment of the activation vectors : mean vector (mean over the N examples).
  # reduce_mean uses the runtime sizes, so this also works when N is not known statically.
  m1 = tf.math.reduce_mean(H, axis=0)

  # m2_hat : second aggregated moment - mean across feature dimensions of the squared mean vector
  m2_hat = tf.math.reduce_mean(tf.math.square(m1), axis=0)

  return m2_hat

def compute_m3_hat(h_vectors_):
  r"""
  Computes the first raw moment (sample mean) of the diagonal elements of the matrix of second-order moments (auto-correlation matrix).
  Diagonal second-order moments are the (sample) second moments around zero of the individual activation features.
  For a matrix H \in R^{N \times d} of N activation vectors h, where h \in R^d :
  m3_hat = \frac{1}{d} \sum_{j=1}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j}^2
  :param h_vectors_: the activation vectors; every axis after the first is flattened into the
      d feature dimensions.
  :return: scalar tensor m3_hat.
  """
  H = tf.layers.flatten(h_vectors_)

  # Diagonal of the auto-correlation matrix: mean over the N examples of each squared feature.
  # reduce_mean uses the runtime sizes, so this also works when N is not known statically.
  diag_second_moments = tf.math.reduce_mean(tf.math.square(H), axis=0)

  # m3_hat : third aggregated moment - mean across feature dimensions of that diagonal
  m3_hat = tf.math.reduce_mean(diag_second_moments, axis=0)

  return m3_hat

def compute_m4_hat(h_vectors_, max_block_size=500):
  r"""
  Computes the first raw moment (sample mean) of the non-diagonal elements of the matrix of second-order moments.
  Non-diagonal second-order moments are the (sample) cross-moments around zero of pairs of activation features.
  For a matrix H \in R^{N \times d} of N activation vectors h, where h \in R^d :
  m4_hat = \frac{1}{d^2 - d} \sum_{k=1}^{d} \sum_{j=1, j \neq k}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j} \times H_{i,k}

  The double sum over all (j, k) pairs equals \sum_i ( \sum_j H_{i,j} )^2, so the d x d matrix
  H^T H never has to be formed: the cost is O(N d) time and memory.

  NOTE: an earlier version summed H^T H only over diagonal blocks of size `max_block_size`.
  That dropped the cross-block terms while still dividing by d(d-1). For layers with
  d > max_block_size, values differ from trajectories saved by that version.

  :param h_vectors_: the activation vectors; every axis after the first is flattened into the
      d feature dimensions.
  :param max_block_size: unused; kept for backward compatibility of the signature.
  :return: scalar tensor m4_hat (inf/nan when d == 1, since there are no off-diagonal pairs).
  """
  del max_block_size  # No longer needed: the exact sum is computed without blocking.
  H = tf.layers.flatten(h_vectors_)
  # Runtime sizes, so unknown static dimensions are supported; cast to H's dtype for the division.
  shape = tf.shape(H)
  n = tf.cast(shape[0], H.dtype)
  d = tf.cast(shape[1], H.dtype)

  # Sum over all (j, k) pairs, diagonal included: sum_i (row sum of H)^2.
  all_pairs_sum = tf.math.reduce_sum(tf.math.square(tf.math.reduce_sum(H, axis=1)))
  # Diagonal (j == k) contribution, removed to keep only the off-diagonal pairs.
  diagonal_sum = tf.math.reduce_sum(tf.math.square(H))

  m4_hat = (all_pairs_sum - diagonal_sum) / (n * d * (d - 1.0))

  return m4_hat

# Compute the neural activation trajectory

In [None]:
# Build the moment ops once, before looping over checkpoints. Restoring a checkpoint only changes
# variable values, so the same ops can be re-evaluated after every restore. Previously they were
# rebuilt for every checkpoint, which kept adding nodes to the graph. An unused tf.Session was
# also opened per checkpoint; self.sess is the session that is actually run.
# moment_ops[k][trial][layer] is the (k+1)-th aggregated moment for one episode and one layer.
moment_fns = (compute_m1_hat, compute_m2_hat, compute_m3_hat, compute_m4_hat)
moment_ops = [[] for _ in moment_fns]
for analysis_trial_num in range(num_analysis_trials):
  # Getting the activations of all the layers of the feature network for this episode
  intermediate_activations = self.embedding_fn(episode_list[analysis_trial_num], is_training=False)['intermediate_activations']
  for ops_of_moment, moment_fn in zip(moment_ops, moment_fns):
    ops_of_moment.append([moment_fn(activation) for activation in intermediate_activations])

# For each checkpoint, computing the activation moments of every layer on each episode
for eval_checkpoint in eval_checkpoints:
  # Restoring the model parameters from the saved checkpoint
  self.saver.restore(self.sess, join(eval_checkpoints_dir, eval_checkpoint))
  # One run evaluates every moment for every episode and layer; the returned values keep the
  # nested [trial][layer] structure of moment_ops.
  m1_tasks_list_values, m2_tasks_list_values, m3_tasks_list_values, m4_tasks_list_values = self.sess.run(moment_ops)
  m1_models_list.append(m1_tasks_list_values)
  m2_models_list.append(m2_tasks_list_values)
  m3_models_list.append(m3_tasks_list_values)
  m4_models_list.append(m4_tasks_list_values)

  analyzed_models_list.append(eval_checkpoint)
  print('Iteration:\t ' + str(eval_checkpoint))

  # Saving after every checkpoint so that an interrupted experiment can be resumed.
  np.save(join(self.summary_dir, 'activation_moments_trajectory.npy'), {'m1_models_list': m1_models_list,
                                                                  'm2_models_list': m2_models_list,
                                                                  'm3_models_list': m3_models_list,
                                                                  'm4_models_list': m4_models_list,
                                                                  'analyzed_models_list': analyzed_models_list,
                                                                  'episode_list': episode_list})

# The experiment is complete. Saving empty file "completed.npy" to signal that the experiment is done.
np.save(join(self.summary_dir, 'completed.npy'), None)

# Comparing the target and source neural activation trajectories

## Utilities

In [None]:
def temporal_smoothing_v3(metric, step=5):
    """Centered moving average of a 1-D series with a half-window of ``step`` samples.

    Output element ``j`` is the mean of ``metric[max(0, j - step) : min(len(metric), j + step + 1)]``.
    The window is truncated (not padded) at both ends of the series.

    :param metric: 1-D sequence of numbers (list or ndarray).
    :param step: half-width of the window; 0 disables smoothing.
    :return: ``metric`` itself when ``step == 0``, otherwise a new float ndarray of the same length.
    :raises ValueError: if ``step`` is negative (this previously produced an all-NaN array).
    """
    if step == 0:
        return metric
    if step < 0:
        raise ValueError('step must be non-negative, got %r' % (step,))
    values = np.asarray(metric)
    length = len(values)
    window_sums = np.zeros(length)
    counts = np.zeros(length)

    # Add the series shifted by every offset in [-step, step]: destination j receives values[j - offset].
    # Shifts that fall entirely outside the series are skipped. The previous slicing produced
    # mismatched shapes for them and crashed whenever step > len(metric).
    for offset in range(-step, step + 1):
        dst_start = max(0, offset)
        dst_end = min(length, length + offset)
        if dst_start >= dst_end:
            continue
        window_sums[dst_start:dst_end] += values[dst_start - offset:dst_end - offset]
        counts[dst_start:dst_end] += 1

    return window_sums / counts

In [None]:
def plot_acc_train_valid(acc_train, iters_train, acc_valid, iters_valid):
    """Plot train and validation accuracy curves in two side-by-side panels.

    In each panel, a dashed vertical line marks the iteration with the highest accuracy.
    """
    panels = (
        ('Train Accuracy', iters_train, acc_train, '$Acc_{train}$'),
        ('Valid Accuracy', iters_valid, acc_valid, '$Acc_{valid}$'),
    )
    plt.figure(figsize=(15,8))
    for position, (panel_title, iters, acc, curve_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.plot(iters, acc, label=curve_label)
        current_axes = plt.gca()
        best_iter = iters[np.argmax(acc)]
        current_axes.axvline(x=best_iter, linestyle='--', linewidth=2.0, c='C1')
        plt.legend(fontsize=15)
        plt.grid()
        plt.xlabel('Training iter', fontsize=15)
        plt.ylabel('Mean Accuracy', fontsize=15)
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)


In [None]:
def plot_acc_target(acc_target, iters_target, target_dataset):
    """Plot the target-domain accuracy curve measured on `target_dataset`.

    A dashed vertical line marks the iteration with the highest target accuracy.
    """
    plt.figure(figsize=(8,8))
    plt.title('Target Accuracy')
    curve_label = '$Acc_{target}$ - ' + target_dataset
    plt.plot(iters_target, acc_target, label=curve_label)
    plt.legend(fontsize=15)
    plt.grid()
    for set_axis_label, axis_text in ((plt.xlabel, 'Training iter'), (plt.ylabel, 'Mean Accuracy')):
        set_axis_label(axis_text, fontsize=15)
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(fontsize=15)
    current_axes = plt.gca()
    best_iter = iters_target[np.argmax(acc_target)]
    current_axes.axvline(x=best_iter, linestyle='--', linewidth=2.0, c='C1')


In [None]:
def plot_activation_moments(moments_source, moments_source_iters, moments_target, moments_target_iters, acc_valid, layers, show_corr, task_id=0):
    """Plot source vs. target moment trajectories: one figure per layer, one panel per moment.

    Indexing used below: moments_target[moment, iteration, task, layer] and
    moments_source[moment, iteration, layer]. When `show_corr` is set, each panel is annotated
    with the Pearson correlation of the two trajectories over their common prefix.
    `acc_valid` is accepted but not used.
    """
    for layer in layers:
        plt.figure(figsize=(15,6))
        panels = []
        num_moments = len(moments_target)
        for moment_idx in range(num_moments):
            panel = plt.subplot(1, num_moments, moment_idx + 1)
            panels.append(panel)
            panel.title.set_text('Layer ' + str(layer+1) + ' - Moment ' + str(moment_idx+1))

            target_traj = moments_target[moment_idx, :, task_id, layer]
            source_traj = moments_source[moment_idx, :, layer]
            panel.plot(moments_target_iters[0:moments_target[moment_idx].shape[0]], target_traj, label='target moments')
            panel.plot(moments_source_iters[0:moments_source[moment_idx].shape[0]], source_traj, label='source moments')

            if show_corr:
                # Correlate only over the iterations both trajectories have.
                common_length = np.min([len(target_traj), len(source_traj)])
                corr = scipy.stats.pearsonr(target_traj[:common_length], source_traj[:common_length])

                textstr = "Corr= " + str(float("{0:.2f}".format(corr[0])))
                props = dict(boxstyle='round', facecolor='white', alpha=1.0)
                panel.text(0.5, 0.5, textstr, transform=panel.transAxes, fontsize=15,
                           verticalalignment='top', bbox=props, zorder=1000)
            panel.legend(loc='best')

        for panel in panels:
            panel.grid()

In [None]:
def get_moments(filename):
    """Load a saved activation-moment trajectory.

    The file must hold a pickled dict with keys 'analyzed_models_list' (checkpoint names such as
    'model_1500.ckpt') and 'm1_models_list' .. 'm4_models_list'.

    :param filename: path of the .npy file (loaded with allow_pickle=True: trusted files only).
    :return: (moments, iters) -- moments[k] is np.asarray of the (k+1)-th moment list, and iters
        is an int array with the training iteration of each analyzed checkpoint.
    """
    record = np.load(filename, allow_pickle=True).item()
    checkpoint_names = record['analyzed_models_list']
    iters = np.asarray([int(name.strip('.ckpt').split('_')[-1]) for name in checkpoint_names])
    moment_keys = ('m1_models_list', 'm2_models_list', 'm3_models_list', 'm4_models_list')
    moments = np.asarray([np.asarray(record[key]) for key in moment_keys])
    return moments, iters


In [None]:
# Directory holding the target accuracies of the first-order MAML run trained on mini-imagenet.
# NOTE(review): absolute, machine-specific path. It is kept as the default for backward
# compatibility; pass `maml_mini_imagenet_dir` to load_target_acc to use another location.
MAML_MINI_IMAGENET_TARGET_ACC_DIR = '/home/user1/PhD-E2020/FS_Model_Selection/Meta-Dataset/meta-dataset/exp/summaries/paper_exps/summaries/representation_analyses/target_accuracies/'


def _load_npy_dict(path):
    """Load a dict that was saved with `np.save(path, some_dict)`.

    Such files hold a pickled object array, so `allow_pickle=True` is required: NumPy >= 1.16.3
    refuses to load them otherwise. Unpickling can execute arbitrary code, so only use this on
    trusted files.
    """
    return np.load(path, allow_pickle=True).item()


def _checkpoint_names_to_iters(checkpoint_names):
    """Map checkpoint names such as 'model_1500.ckpt' to an integer array of iterations."""
    return np.asarray([int(os.path.splitext(name)[0].split('_')[-1]) for name in checkpoint_names])


def load_target_acc(exp_dir_target_acc, algo, source_dataset, run, target_dataset,
                    maml_mini_imagenet_dir=MAML_MINI_IMAGENET_TARGET_ACC_DIR):
    """Load the target-domain accuracy trajectory of one training run.

    :param exp_dir_target_acc: root directory of the accuracy files. It must end with a path
        separator, since file paths are built by string concatenation.
    :param algo: algorithm name used in the file names, e.g. 'maml'.
    :param source_dataset: dataset the model was trained on.
    :param run: run index used in the file names (ignored for maml on mini-imagenet).
    :param target_dataset: dataset the accuracies were measured on.
    :param maml_mini_imagenet_dir: directory used instead of `exp_dir_target_acc` when
        algo == 'maml' and source_dataset == 'mini-imagenet'; must end with a path separator.
    :return: (acc_target, iters_target): the mean target accuracies and the training iterations
        of the corresponding checkpoints, or (None, None) if no accuracy file is found.
    """
    if algo == 'maml' and source_dataset == 'mini-imagenet':
        # NOTE(review): this branch ignores `run` and `exp_dir_target_acc`.
        analysis = _load_npy_dict(maml_mini_imagenet_dir + 'maml-fo_5way1shot_mini-imagenet_from_scratch_convnet_target_accuracy_analysis_' + target_dataset + '/target_accuracy.npy')
        return (np.asarray(analysis['mean_acc_models_list']),
                _checkpoint_names_to_iters(analysis['analyzed_models_list']))

    legacy_filename = exp_dir_target_acc + algo + '_' + source_dataset + '_run-' + str(run) + '_analysis_' + target_dataset + '/analysis.npy'
    filename = exp_dir_target_acc + 'target_accuracies/' + algo + '_' + source_dataset + '_run-' + str(run) + '_target_accuracy_analysis_' + target_dataset + '/target_accuracy.npy'
    # Both the legacy and the new file layouts are checked; either or both may exist.
    analysis_target_legacy = _load_npy_dict(legacy_filename) if os.path.isfile(legacy_filename) else None
    analysis_target_new = _load_npy_dict(filename) if os.path.isfile(filename) else None

    if analysis_target_legacy is None and analysis_target_new is None:
        print('Target accuracy not found !')
        return None, None

    # Taking the experiment that ran for the longest (the new-format file wins ties).
    if analysis_target_new is None:
        analysis_target = analysis_target_legacy
    elif analysis_target_legacy is None:
        analysis_target = analysis_target_new
    elif len(analysis_target_legacy['analyzed_models_list']) > len(analysis_target_new['analyzed_models_list']):
        analysis_target = analysis_target_legacy
    else:
        analysis_target = analysis_target_new

    return (np.asarray(analysis_target['mean_acc_models_list']),
            _checkpoint_names_to_iters(analysis_target['analyzed_models_list']))


## Function for identifying the critical layer


In [None]:
def critical_layer_min_corr_avg_over_moments(task_idx, moments_target, moments_target_iters, moments_source, acc_valid, iters_valid):
    """Return the layer whose source and target moment trajectories are least correlated.

    For each layer, the Pearson correlation between the target and source trajectories is
    averaged over all moments. Only iterations from initialization up to and including the
    checkpoint closest to the best validation accuracy are used.

    :param task_idx: index of the target task (axis 2 of moments_target).
    :param moments_target: array indexed [moment, iteration, task, layer].
    :param moments_target_iters: ndarray of training iterations of the target trajectory.
    :param moments_source: array indexed [moment, iteration, layer].
    :param acc_valid: validation accuracies, aligned with iters_valid.
    :param iters_valid: iterations at which acc_valid was measured.
    :return: index of the critical layer, or None if no layer has a finite average correlation.
    """
    num_moments = moments_target.shape[0]
    # Only considering the trajectory between init and the iteration of max validation accuracy,
    # inclusively. The cutoff depends on neither the layer nor the moment, so compute it once.
    best_valid_iter = iters_valid[np.argmax(acc_valid)]
    cutoff = np.argmin(np.abs(moments_target_iters - best_valid_iter)) + 1
    max_length = min(cutoff, moments_target.shape[1], moments_source.shape[1])

    min_layer_corr = np.inf  # np.infty was removed in NumPy 2.0
    critical_layer_idx = None
    for layer_idx in range(moments_target.shape[3]):
        # Averaging the moment correlations at this layer
        corr_layer = 0.0
        for moment_idx in range(num_moments):
            corr, _ = scipy.stats.pearsonr(
                moments_target[moment_idx, :max_length, task_idx, layer_idx],
                moments_source[moment_idx, :max_length, layer_idx])
            corr_layer += corr
        corr_layer /= num_moments
        if corr_layer < min_layer_corr:
            min_layer_corr = corr_layer
            critical_layer_idx = layer_idx

    return critical_layer_idx

## Function for identifying the critical moment

In [None]:
def critical_moment_min_corr_before_max_valid(task_idx, critical_layer_idx, moments_target, moments_target_iters, moments_source, acc_valid, iters_valid):
    """Return the moment whose source and target trajectories are least correlated at a layer.

    Only iterations from initialization up to and including the checkpoint closest to the best
    validation accuracy are used. (Previously the `return` sat inside the loop, so the function
    always returned moment 0 after looking at the first moment only.)

    :param task_idx: index of the target task (axis 2 of moments_target).
    :param critical_layer_idx: layer at which the moments are compared.
    :param moments_target: array indexed [moment, iteration, task, layer].
    :param moments_target_iters: ndarray of training iterations of the target trajectory.
    :param moments_source: array indexed [moment, iteration, layer].
    :param acc_valid: validation accuracies, aligned with iters_valid.
    :param iters_valid: iterations at which acc_valid was measured.
    :return: index of the critical moment, or None if no moment has a finite correlation.
    """
    # Only considering the trajectory between init and the iteration of max validation accuracy,
    # inclusively. The cutoff does not depend on the moment, so compute it once.
    best_valid_iter = iters_valid[np.argmax(acc_valid)]
    cutoff = np.argmin(np.abs(moments_target_iters - best_valid_iter)) + 1
    max_length = min(cutoff, moments_target.shape[1], moments_source.shape[1])

    min_moment_corr = np.inf  # np.infty was removed in NumPy 2.0
    critical_moment_idx = None
    for moment_idx in range(moments_target.shape[0]):
        corr, _ = scipy.stats.pearsonr(
            moments_target[moment_idx, :max_length, task_idx, critical_layer_idx],
            moments_source[moment_idx, :max_length, critical_layer_idx])
        if corr < min_moment_corr:
            min_moment_corr = corr
            critical_moment_idx = moment_idx

    return critical_moment_idx

## Function for computing the stopping-time score

In [None]:
def compute_score_rightside_neg_corr_windows(task_idx, critical_layer_idx, critical_moment_idx, moments_target, moments_target_iters, moments_source, acc_valid, iters_valid, multiply_score_by_interval):
    """Score each candidate stopping time by the anti-correlation of the trajectories after it.

    The target and source trajectories of the critical moment at the critical layer are first
    cut at the checkpoint closest to the best validation accuracy (inclusive). For every index t,
    the score is -corr(x[t:], y[t:]), multiplied by len(x[t:]) when
    `multiply_score_by_interval` is set. It is 0.0 when fewer than two points remain.

    :param task_idx: index of the target task (axis 2 of moments_target).
    :param critical_layer_idx: layer of the trajectories to compare.
    :param critical_moment_idx: moment of the trajectories to compare.
    :param moments_target: array indexed [moment, iteration, task, layer].
    :param moments_target_iters: ndarray of training iterations of the target trajectory.
    :param moments_source: array indexed [moment, iteration, layer].
    :param acc_valid: validation accuracies, aligned with iters_valid.
    :param iters_valid: iterations at which acc_valid was measured.
    :param multiply_score_by_interval: weight each score by the length of its window.
    :return: list with one score per index t of the truncated trajectories.
    """
    best_valid_iter = iters_valid[np.argmax(acc_valid)]
    cutoff = np.argmin(np.abs(moments_target_iters - best_valid_iter)) + 1
    max_length = min(cutoff, moments_target.shape[1], moments_source.shape[1])

    x = moments_target[critical_moment_idx, :max_length, task_idx, critical_layer_idx]
    y = moments_source[critical_moment_idx, :max_length, critical_layer_idx]
    scores = []
    for t in range(len(x)):
        # Only the right region [t, end) is scored: its negative correlation, optionally times its width.
        window_len = len(x) - t
        score = 0.0
        if window_len > 1:
            corr_neg, _ = scipy.stats.pearsonr(x[t:], y[t:])
            score = -(corr_neg * window_len) if multiply_score_by_interval else -corr_neg
        scores.append(score)

    return scores

## Function for performing the analysis and computing the performance

In [None]:
# Early-stopping method: stop when target activation moments start to anti-correlate with source moments.
# critical layer (CL) : Where correlation between source and target moments is lowest, averaged on the four moments
#     difference with v1 => time window : trajectory before iter max valid acc
# critical moment (CM) : At CL, the moment where correlation is lowest between source and target
#     difference with v1 => time window : trajectory before iter max valid acc
# critical time (CT) : At CL and CM, the time where the right region of negative correlation (-1*corr),
#                      multiplied by its width, is highest.
#                      time window : trajectory before iter max valid acc
# CT limit : iter max valid acc
# With the 'max_score_rightside_neg_corr_windows' criteria, CL and/or CM are instead taken from the
# (layer, moment) pair whose best right-side window score is highest.


def _check_criterion(option_name, value, supported):
    """Raise a ValueError listing the supported values when `value` is not one of them."""
    if value not in supported:
        raise ValueError('Provided %s NOT SUPPORTED ! Got %r, expected one of: %s'
                         % (option_name, value, ', '.join(supported)))


def _window_length_before_max_valid(moments_target, moments_target_iters, moments_source, acc_valid, iters_valid):
    """Number of leading trajectory points up to (and including) the iteration closest to max valid accuracy."""
    max_length = min(moments_target.shape[1], moments_source.shape[1])
    iter_max_valid = iters_valid[np.argmax(acc_valid)]
    limit = int(np.argmin(np.abs(moments_target_iters - iter_max_valid))) + 1
    return min(limit, max_length)


def _best_pair_by_score(task_idx, layer_indices, moment_indices, moments_target, moments_target_iters,
                        moments_source, acc_valid, iters_valid, multiply_score_by_interval):
    """Return the (layer_idx, moment_idx) pair with the highest best-window score (first pair wins ties)."""
    best_score, best_pair = -np.inf, None
    for layer_idx in layer_indices:
        for moment_idx in moment_indices:
            scores = compute_score_rightside_neg_corr_windows(task_idx, layer_idx, moment_idx, moments_target,
                                                              moments_target_iters, moments_source, acc_valid,
                                                              iters_valid, multiply_score_by_interval)
            # nan-aware: an undefined correlation must neither win nor disqualify the whole pair.
            score = np.nanmax(scores)
            if score > best_score:
                best_score, best_pair = score, (layer_idx, moment_idx)
    if best_pair is None:
        raise ValueError('No right-side negative-correlation score could be computed for task %d.' % task_idx)
    return best_pair


def _plot_stopping(task_idx, critical_layer_idx, critical_moment_idx, scores, t_stopping, max_score,
                   moments_target, moments_target_iters, moments_source, acc_valid, iters_valid):
    """Plot the window scores and the target/source moment trajectories used to pick the stopping time."""
    window = _window_length_before_max_valid(moments_target, moments_target_iters, moments_source,
                                             acc_valid, iters_valid)
    iters_window = moments_target_iters[:window]
    x = moments_target[critical_moment_idx, :window, task_idx, critical_layer_idx]
    y = moments_source[critical_moment_idx, :window, critical_layer_idx]

    if task_idx == 0:
        plt.figure(figsize=(16, 8))
        # suptitle rather than plt.title: plt.title would create an extra full-figure axes that the
        # three subplots below would overlap.
        plt.suptitle('scores - target moments - source moments')
    plt.subplot(1, 3, 1)
    plt.plot(iters_window, scores, label=(critical_layer_idx, critical_moment_idx, t_stopping, task_idx))
    plt.scatter(moments_target_iters[t_stopping], max_score, s=50)
    plt.legend()
    plt.grid()
    for position, trajectory in ((2, x), (3, y)):
        plt.subplot(1, 3, position)
        plt.plot(iters_window, trajectory, label=task_idx)
        plt.legend()
        plt.grid()


def negative_correlation_before_max_valid(moments_target, moments_target_iters, moments_source, acc_valid, iters_valid,
                                          acc_target, iters_target, show_stopping=False, multiply_score_by_interval=True,
                                          stop_at_first_maximum=False, force_layer_moment_inspection=None,
                                          find_layer_and_moment_by_score=False,
                                          critical_layer_criterion='critical_layer_min_corr_avg_over_moments',
                                          critical_moment_criterion='critical_moment_min_corr_before_max_valid',
                                          critical_time_criterion='max_score_rightside_neg_corr_windows',
                                          target_dataset=None):
    """Evaluate the negative-correlation early-stopping method on every task of a target dataset.

    For each target task, a critical layer, a critical moment and a critical time (stopping point)
    are selected. The stopping iteration is capped at the iteration of max validation accuracy, and
    the target accuracy at that iteration is recorded.

    Args:
        moments_target: target moments, shape [num_moments, num_iters, num_tasks, num_layers].
        moments_target_iters: array of iterations at which the moments were recorded.
        moments_source: source moments, shape [num_moments, num_iters, num_layers].
        acc_valid, iters_valid: validation accuracy curve and its iterations.
        acc_target, iters_target: 1-D target accuracy curve and its iterations.
        show_stopping: plot scores and trajectories for the first 10 tasks.
        multiply_score_by_interval: weight -corr by the width of its window.
        stop_at_first_maximum: unused; kept for backward compatibility.
        force_layer_moment_inspection: optional (layer_idx, moment_idx) overriding CL/CM (debugging).
        find_layer_and_moment_by_score: unused; kept for backward compatibility.
        critical_layer_criterion: 'critical_layer_min_corr_avg_over_moments' or
            'max_score_rightside_neg_corr_windows'.
        critical_moment_criterion: 'critical_moment_min_corr_before_max_valid' or
            'max_score_rightside_neg_corr_windows'.
        critical_time_criterion: 'max_score_rightside_neg_corr_windows'.
        target_dataset: name printed in the summary. Defaults to the notebook-level `target_dataset`
            variable, if any, so existing callers keep their output.

    Returns:
        (perf_avg, perf_valid, critical_layer_idx, critical_moment_idx, list_perf, list_iter_stopping).
        The critical indices are those of the last task.

    Raises:
        ValueError: unsupported criterion, no target task, or no computable score.
    """
    score_criterion = 'max_score_rightside_neg_corr_windows'
    _check_criterion('critical_layer_criterion', critical_layer_criterion,
                     ('critical_layer_min_corr_avg_over_moments', score_criterion))
    _check_criterion('critical_moment_criterion', critical_moment_criterion,
                     ('critical_moment_min_corr_before_max_valid', score_criterion))
    _check_criterion('critical_time_criterion', critical_time_criterion, (score_criterion,))

    num_moments, _, num_tasks, num_layers = moments_target.shape
    if num_tasks == 0:
        raise ValueError('moments_target contains no target task (axis 2 is empty).')
    layer_indices = list(range(num_layers))
    moment_indices = list(range(num_moments))
    # Stopping point of the validation baseline, and latest iteration our method may stop at.
    iter_max_valid_acc = iters_valid[np.argmax(acc_valid)]

    list_perf = []
    list_iter_stopping = []
    # Evaluating the method on every target task
    for task_idx in range(num_tasks):
        # 1. Identifying at which layer the distributional shift is highest
        best_pair = None
        if critical_layer_criterion == score_criterion:
            best_pair = _best_pair_by_score(task_idx, layer_indices, moment_indices, moments_target,
                                            moments_target_iters, moments_source, acc_valid, iters_valid,
                                            multiply_score_by_interval)
            critical_layer_idx = best_pair[0]
        else:
            critical_layer_idx = critical_layer_min_corr_avg_over_moments(task_idx, moments_target, moments_target_iters,
                                                                          moments_source, acc_valid, iters_valid)

        # 2. Identifying which moment drives the distributional shift the most
        if critical_moment_criterion == score_criterion:
            if best_pair is None:
                # Search only the moments of the already-selected critical layer.
                best_pair = _best_pair_by_score(task_idx, [critical_layer_idx], moment_indices, moments_target,
                                                moments_target_iters, moments_source, acc_valid, iters_valid,
                                                multiply_score_by_interval)
            critical_moment_idx = best_pair[1]
        else:
            critical_moment_idx = critical_moment_min_corr_before_max_valid(task_idx, critical_layer_idx, moments_target,
                                                                            moments_target_iters, moments_source,
                                                                            acc_valid, iters_valid)

        # For debugging purposes
        if force_layer_moment_inspection is not None:
            critical_layer_idx = force_layer_moment_inspection[0]
            critical_moment_idx = force_layer_moment_inspection[1]

        # 3. Identifying at which time the critical moment, at the critical layer, starts to decorrelate.
        # Only the trajectory between init and iter max valid (inclusive) is considered.
        scores = compute_score_rightside_neg_corr_windows(task_idx, critical_layer_idx, critical_moment_idx,
                                                          moments_target, moments_target_iters, moments_source,
                                                          acc_valid, iters_valid, multiply_score_by_interval)
        # nan-aware, so an undefined correlation is never chosen as the stopping time.
        t_stopping = int(np.nanargmax(scores))
        max_score = np.nanmax(scores)

        if show_stopping and task_idx < 10:
            _plot_stopping(task_idx, critical_layer_idx, critical_moment_idx, scores, t_stopping, max_score,
                           moments_target, moments_target_iters, moments_source, acc_valid, iters_valid)

        # Our method stops before or at the maximum validation, not after
        iter_stopping = np.min([moments_target_iters[t_stopping], iter_max_valid_acc])

        # 4. Computing the generalization performance when using the method on the current target task
        perf_task = acc_target[np.argmin(np.abs(iters_target - iter_stopping))]
        list_perf.append(perf_task)
        list_iter_stopping.append(iter_stopping)

    # Averaging the performance of our method on the current target dataset
    perf_avg = sum(list_perf, 0.0) / num_tasks
    # Computing the performance of the validation baseline
    perf_valid = acc_target[np.argmin(np.abs(iters_target - iter_max_valid_acc))]

    if target_dataset is None:
        # Backward compatibility: callers that do not pass the name still get the notebook-level
        # loop variable (when it exists) instead of a NameError.
        target_dataset = globals().get('target_dataset', 'unknown')
    # Printing the performance
    print('TARGET : ' + str(target_dataset))
    print('our method : ' + str(perf_avg) + ' valid baseline : ' + str(perf_valid))

    return perf_avg, perf_valid, critical_layer_idx, critical_moment_idx, list_perf, list_iter_stopping

## Perform early-stopping

In [None]:
# Directories containing raw data
# NOTE(review): absolute, user-specific paths -- adapt them to the local machine before running.
exp_dir_activation_trajectory = '/home/user1/PhD-A2021/Neural_Computation/exps/baselines/BACKUP_activation_moments_trajectory_3/'
# Common parent directory of the accuracy summaries of the paper experiments
paper_exps_dir = '/home/user1/PhD-E2020/FS_Model_Selection/Meta-Dataset/meta-dataset/exp/summaries/paper_exps/'
exp_dir_target_acc = paper_exps_dir + 'summaries/representation_analyses/'
exp_dir_valid_acc = paper_exps_dir + 'valid_acc/'
exp_dir_train_acc = paper_exps_dir + 'train_acc/'

# Fields for the experiments
algos = ['maml', 'prototypical', 'matching']
source_datasets = ['cu_birds', 'dtd', 'aircraft', 'omniglot', 'vgg_flower', 'imagenet', 'mini-imagenet', 'quickdraw']
target_datasets = ['cu_birds', 'dtd', 'aircraft', 'omniglot', 'vgg_flower', 'ilsvrc_2012', 'mini_imagenet', 'quickdraw', 'traffic_sign']

# Training run analysed for each (algorithm, source dataset) pair; None means "model not trained yet"
training_runs_dict = {
    'maml': {'cu_birds': 4, 'dtd': 3, 'aircraft': 2, 'omniglot': 2, 'vgg_flower': 4, 'imagenet': 1, 'mini-imagenet': 1, 'quickdraw': 5},
    'prototypical': {'cu_birds': 1, 'dtd': 3, 'aircraft': 4, 'omniglot': 2, 'vgg_flower': 5, 'imagenet': 1, 'mini-imagenet': 1, 'quickdraw': 1},
    'matching': {'cu_birds': 1, 'dtd': 4, 'aircraft': 1, 'omniglot': 1, 'vgg_flower': 1, 'imagenet': 2, 'mini-imagenet': 1, 'quickdraw': 1},
}
layers = [0, 1, 2, 3]

# Debug override of the (critical layer, critical moment) pair; None lets the method choose.
# Set BEFORE the single-experiment block so that an override chosen there is not discarded.
force_layer_moment_inspection = None

# If inspecting a single experiment, i.e. (algo, source dataset, target dataset)
inspect_single_exp = False
if inspect_single_exp:
    algos = ['matching']
    source_datasets = ['quickdraw']
    target_datasets = ['omniglot']
    # force_layer_moment_inspection = [3, 0]

show_plots = False
# Options for data analysis
plot_accuracy = show_plots
plot_moments = show_plots
show_stopping = show_plots
smooth_acc_train = True
smooth_acc_valid = True
smooth_acc_target = True
smooth_moments = False
multiply_score_by_interval = True
stop_at_first_maximum = False
num_smoothing_steps = 3
show_corr = True
task_id = 0  # inspecting a single target task (debugging)

stopping_criterion = 'negative_correlation_before_max_valid'
critical_layer_criterion = 'max_score_rightside_neg_corr_windows'
critical_moment_criterion = 'max_score_rightside_neg_corr_windows'
critical_time_criterion = 'max_score_rightside_neg_corr_windows'

# Snapshot of the analysis options (kept for bookkeeping alongside the results)
analysis_arguments = {
    'smooth_acc_train': smooth_acc_train,
    'smooth_acc_valid': smooth_acc_valid,
    'smooth_acc_target': smooth_acc_target,
    'smooth_moments': smooth_moments,
    'multiply_score_by_interval': multiply_score_by_interval,
    'stop_at_first_maximum': stop_at_first_maximum,
    'num_smoothing_steps': num_smoothing_steps,
    'stopping_criterion': stopping_criterion,
    'critical_layer_criterion': critical_layer_criterion,
    'critical_moment_criterion': critical_moment_criterion,
    'critical_time_criterion': critical_time_criterion,
}

In [None]:
# Fail fast on an unsupported stopping criterion, before any file is loaded.
if stopping_criterion != 'negative_correlation_before_max_valid':
    raise ValueError('Invalid stopping criterion ! Options :\n[negative_correlation_before_max_valid]')

performance_dict = {}
start = time.time()
for algo in algos:
    print('*********************************************************')
    print('ALGO : ' + algo)
    performance_dict[algo] = {}
    for source_dataset in source_datasets:
        print('---------------------------------------------------------')
        print('SOURCE : ' + source_dataset)
        print('---------------------------------------------------------')
        run = training_runs_dict[algo][source_dataset]
        if run is None:
            print('Model not trained yet.')
            continue
        # '<algo>_<source>_run-<run>' stem shared by every file of this training run
        run_name = algo + '_' + source_dataset + '_run-' + str(run)

        # Loading train accuracy (optional: only used for plotting)
        try:
            results_train = np.loadtxt(exp_dir_train_acc + 'run-' + run_name + '-tag-train_acc.csv',
                                       delimiter=',', skiprows=1)
            acc_train = results_train[:, 2]
            if smooth_acc_train:
                acc_train = temporal_smoothing_v3(acc_train, step=num_smoothing_steps)
            iters_train = results_train[:, 1].astype(int)
        except Exception:
            # Reset, so the curve of the previous source dataset is never plotted by mistake.
            acc_train, iters_train = None, None
            print('Could not load train accuracy')
        # Loading valid accuracy (required by the method)
        try:
            results_valid = np.loadtxt(exp_dir_valid_acc + 'run-' + run_name + '-tag-mean valid acc.csv',
                                       delimiter=',', skiprows=1)
            acc_valid = results_valid[:, 2]
            if smooth_acc_valid:
                acc_valid = temporal_smoothing_v3(acc_valid, step=num_smoothing_steps)
            iters_valid = results_valid[:, 1].astype(int)
        except Exception:
            print('Could not load valid accuracy')
            continue
        # Plotting the training and validation accuracy
        if plot_accuracy:
            if acc_train is None:
                print('Train or Valid accuracy is missing.')
            else:
                try:
                    plot_acc_train_valid(acc_train, iters_train, acc_valid, iters_valid)
                except Exception:
                    print('Train or Valid accuracy is missing.')
        # Loading source activation moments (computed on the source dataset itself)
        try:
            dataset_valid_moments = source_dataset
            moments_source, moments_source_iters = get_moments(
                exp_dir_activation_trajectory + run_name + '_activation_moments_trajectory_analysis_'
                + dataset_valid_moments + '/activation_moments_trajectory.npy')
            # Averaging source moments across all tasks
            moments_source = np.mean(moments_source, axis=2)
            if smooth_moments:
                for i in range(moments_source.shape[0]):
                    for j in range(moments_source.shape[2]):
                        moments_source[i, :, j] = temporal_smoothing_v3(moments_source[i, :, j], step=num_smoothing_steps)
        except Exception:
            print('Could not load the source activation moments')
            continue

        performance_dict[algo][source_dataset] = {}
        # Looping over all target datasets
        for target_dataset in target_datasets:
            # Loading target accuracy
            acc_target, iters_target = load_target_acc(exp_dir_target_acc, algo, source_dataset, run,
                                                       target_dataset)
            # Check the load BEFORE touching the arrays: np.mean(None, axis=1) would raise instead of skipping.
            if (acc_target is None) or (iters_target is None):
                print('Could not load the target accuracy.')
                continue
            # Averaging over axis 1 (presumably the target tasks -- confirm against load_target_acc)
            acc_target = np.mean(acc_target, axis=1)
            if smooth_acc_target:
                acc_target = temporal_smoothing_v3(acc_target, step=num_smoothing_steps)
            if plot_accuracy:
                plot_acc_target(acc_target, iters_target, target_dataset)

            # Loading target activation moments trajectory
            try:
                moments_target, moments_target_iters = get_moments(
                    exp_dir_activation_trajectory + run_name + '_activation_moments_trajectory_analysis_'
                    + target_dataset + '/activation_moments_trajectory.npy')
            except Exception:
                print('Could not load the target activation moments.')
                continue

            if smooth_moments:
                for i in range(moments_target.shape[0]):
                    for j in range(moments_target.shape[2]):
                        for k in range(moments_target.shape[3]):
                            moments_target[i, :, j, k] = temporal_smoothing_v3(moments_target[i, :, j, k], step=num_smoothing_steps)

            if plot_moments:
                plot_activation_moments(moments_source, moments_source_iters, moments_target, moments_target_iters,
                                        acc_valid, layers, show_corr, task_id)
            # Check if moment trajectories have more than one point
            if moments_target.shape[1] <= 1 or moments_source.shape[1] <= 1:
                print('Moment trajectories less than two points. Cannot evaluate method.')
                continue

            # EVALUATING THE METHOD PERFORMANCE
            (perf_avg, perf_valid, critical_layer_idx, critical_moment_idx,
             list_perf, list_iter_stopping) = negative_correlation_before_max_valid(
                moments_target, moments_target_iters, moments_source, acc_valid, iters_valid, acc_target, iters_target,
                show_stopping=show_stopping, multiply_score_by_interval=multiply_score_by_interval,
                stop_at_first_maximum=stop_at_first_maximum,
                force_layer_moment_inspection=force_layer_moment_inspection,
                critical_layer_criterion=critical_layer_criterion,
                critical_moment_criterion=critical_moment_criterion,
                critical_time_criterion=critical_time_criterion)

            # Saving the performances in the dictionary
            performance_dict[algo][source_dataset][target_dataset] = {
                'critical_layer_idx': critical_layer_idx,
                'critical_moment_idx': critical_moment_idx,
                'list_perf': list_perf,
                'list_iter_stopping': list_iter_stopping,
                'our_method': perf_avg,
                'baseline_validation': perf_valid,
                'acc_target_max': np.max(acc_target),
                'acc_target_min': np.min(acc_target),
            }

print('\n\n\n###################################################')
print('ANALYSIS COMPLETE !')
end = time.time()
print('Running time : ' + str(end - start))