In [10]:
import numpy as np
import sys

import matplotlib.pyplot as plt
import scipy.stats
import scipy.signal
import os
import pandas as pd
import torch
import time

In [11]:
# Functions to compute the aggregated moments

In [12]:
def compute_m1_hat(h_vectors_):
  r"""
  First aggregated moment: the mean, across feature dimensions, of the per-feature sample means.
  The input is flattened to a matrix H \in R^{N \times d} (N activation vectors h \in R^d), and
  m1_hat = \frac{1}{d} \sum_{j=1}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j}
  :param h_vectors_: tensor of activations; its first dimension indexes the N samples and all
                     remaining dimensions are flattened into the d features
  :return: 0-dim tensor holding m1_hat
  """
  flat = h_vectors_.view(h_vectors_.shape[0], -1)

  # Per-feature sample mean (the mean activation vector), then its average over the features
  feature_means = flat.mean(dim=0)
  return feature_means.mean()

def compute_m2_hat(h_vectors_):
  r"""
  Second aggregated moment: the mean, across feature dimensions, of the squared per-feature
  sample means (i.e. the second raw moment of the mean activation vector).
  The input is flattened to a matrix H \in R^{N \times d} (N activation vectors h \in R^d), and
  m2_hat = \frac{1}{d} \sum_{j=1}^{d} ( \frac{1}{N} \sum_{i=1}^{N} H_{i,j} )^2
  :param h_vectors_: tensor of activations; its first dimension indexes the N samples and all
                     remaining dimensions are flattened into the d features
  :return: 0-dim tensor holding m2_hat
  """
  flat = h_vectors_.view(h_vectors_.shape[0], -1)

  # Per-feature sample mean (the mean activation vector)
  feature_means = flat.mean(dim=0)

  # Average of the squared entries of the mean activation vector
  return feature_means.pow(2).mean()

def compute_m3_hat(h_vectors_):
  r"""
  Third aggregated moment: the mean of the diagonal of the matrix of second-order moments
  \frac{1}{N} H^T H. Each diagonal entry is the uncentered second moment (mean of squares)
  of one activation feature.
  The input is flattened to a matrix H \in R^{N \times d} (N activation vectors h \in R^d), and
  m3_hat = \frac{1}{d} \sum_{j=1}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j}^2
  :param h_vectors_: tensor of activations; its first dimension indexes the N samples and all
                     remaining dimensions are flattened into the d features
  :return: 0-dim tensor holding m3_hat
  """
  flat = h_vectors_.view(h_vectors_.shape[0], -1)

  # Diagonal of the second-order moment matrix: per-feature mean of squared activations
  second_moment_diag = flat.pow(2).mean(dim=0)

  # Average of that diagonal across the features
  return second_moment_diag.mean()

def compute_m4_hat(h_vectors_, max_block_size=500):
  r"""
  Fourth aggregated moment: the mean of the off-diagonal entries of the matrix of second-order
  moments \frac{1}{N} H^T H (the uncentered cross-moments between pairs of distinct features).
  The input is flattened to a matrix H \in R^{N \times d} (N activation vectors h \in R^d), and
  m4_hat = \frac{1}{d^2 - d} \sum_{k=1}^{d} \sum_{j=1, j \neq k}^{d} \frac{1}{N} \sum_{i=1}^{N} H_{i,j} \times H_{i,k}

  The d x d matrix is never built. For every sample i,
  \sum_{j,k} H_{i,j} H_{i,k} = ( \sum_{j} H_{i,j} )^2 ,
  so the sum over ALL entries of H^T H equals the sum of the squared row sums of H, and
  subtracting \sum_{i,j} H_{i,j}^2 removes the diagonal. This is exact for any d, including the
  products between features that would fall into different blocks of a blocked matmul.

  :param h_vectors_: tensor of activations; its first dimension indexes the N samples and all
                     remaining dimensions are flattened into the d features
  :param max_block_size: kept for backward compatibility of the signature; ignored, because no
                         matrix multiplication is performed any more
  :return: 0-dim tensor holding m4_hat. With d == 1 there is no off-diagonal entry and the
           result is the outcome of a 0 / 0 division.
  """
  H = h_vectors_.view(h_vectors_.shape[0], -1)
  N, d = H.shape

  # Sum of every entry of H^T H, obtained from the per-sample row sums
  row_sums = torch.sum(H, dim=1)
  all_entries = torch.sum(row_sums ** 2)

  # Sum of the diagonal entries of H^T H
  diagonal = torch.sum(H ** 2)

  # Average over the d * (d - 1) off-diagonal entries and over the N samples
  m4_hat = (all_entries - diagonal) / (N * (d * (d - 1)))

  return m4_hat



In [13]:
# Objective function of ABE

In [14]:
def ABE_objective(traj_target, traj_source, t1, t2):
    """
    The objective of ABE. Corresponds to Equation 4 (page 5) of https://arxiv.org/abs/2208.02377

    The score is the negated Pearson correlation between the two trajectories restricted to the
    window [t1, t2) (Python slice semantics), multiplied by (t2 - t1 + 1).

    :param traj_target: sequence of values of the target trajectory
    :param traj_source: sequence of values of the source trajectory
    :param t1: first index of the window (inclusive)
    :param t2: end index of the window (exclusive, as used in the slices)
    :return: 0.0 when the window is empty (t2 <= t1); NaN when the window holds fewer than two
             samples, since the correlation is undefined there; otherwise the score described
             above (which is itself NaN if SciPy returns a NaN correlation)
    """
    if t2 - t1 <= 0:
        return 0.0

    window_target = traj_target[t1:t2]
    window_source = traj_source[t1:t2]

    # Pearson correlation needs at least two samples; scipy.stats.pearsonr raises ValueError on
    # shorter inputs. Report the undefined correlation as NaN instead of crashing, so that a
    # caller comparing scores with `>` simply never selects this window.
    if min(len(window_target), len(window_source)) < 2:
        return float("nan")

    corr, _ = scipy.stats.pearsonr(window_target, window_source)
    # The score is the product of the negative (Pearson) correlation and the width of the time interval
    score = - (corr * (t2 - t1 + 1))
    return score

In [None]:
# Finding the critical time by maximizing the objective

In [15]:
def critical_time(task_idx, critical_layer_idx, critical_moment_idx, moments_target, moments_target_iters, moments_source, acc_valid, iters_valid):
    """
    Computes the critical time, and the associated objective value, for a pair of target and source trajectories.
    The critical time simply corresponds to the iteration where the objective is maximal, within the interval [t0, t*_valid]

    :param task_idx: index of the task, used on the third axis of moments_target
    :param critical_layer_idx: index of the layer, used on the last axis of both moment arrays
    :param critical_moment_idx: index of the moment, used on the first axis of both moment arrays
    :param moments_target: array indexed as [moment, time, task, layer]
    :param moments_target_iters: iterations at which the target moments were recorded
                                 (must support array arithmetic, e.g. a NumPy array)
    :param moments_source: array indexed as [moment, time, layer]
    :param acc_valid: validation accuracies
    :param iters_valid: iterations at which acc_valid was recorded (indexed like acc_valid)
    :return: (critical_time_, max_score), where critical_time_ is the start index t1 that maximises
             ABE_objective over the windows [t1, t2). If no window contains at least two
             samples, (0, -inf) is returned.
    """
    traj_target = moments_target[critical_moment_idx, :, task_idx, critical_layer_idx]
    traj_source = moments_source[critical_moment_idx, :, critical_layer_idx]

    # t2 is the iteration of t*_valid (or the end of the experiment): one past the index of the
    # recorded target iteration closest to the iteration of best validation accuracy, clipped to
    # the length of the shorter trajectory.
    best_valid_iter = iters_valid[np.argmax(acc_valid)]
    limit = int(np.argmin(np.abs(moments_target_iters - best_valid_iter))) + 1
    t2 = min(limit, len(traj_target), len(traj_source))

    max_score = -np.inf
    critical_time_ = 0
    t0 = 0
    # Stop at t2 - 2: the window starting at t2 - 1 holds a single sample, for which the Pearson
    # correlation is undefined (recent SciPy raises on it), so it could never be selected.
    for t1 in range(t0, t2 - 1):
        score = ABE_objective(traj_target, traj_source, t1, t2)
        if score > max_score:
            max_score = score
            critical_time_ = t1

    return critical_time_, max_score