### **SPRKD (SADDLE POINT RECRUITMENT FOR KNOWLEDGE DISTILLATION): HIGHER-ORDER EXPERIMENTATION AND ANALYSIS.**

In [None]:
#Install necessary libraries
#Hessian eigenthings and PyHessian
!pip install --upgrade "git+https://github.com/thetechdude124/pytorch-hessian-eigenthings.git@master#egg=hessian-eigenthings"
#Install PyHessian library (personal fork with version-specific updates)
!pip install --upgrade "git+https://github.com/thetechdude124/pyhessian.git@master#egg=pyhessian"
#PIL for image processing
!pip install pillow-simd
#Change numpy to version 1.24.0 (needed for TLI saddle injection)
#-y skips pip's "Proceed (Y/n)?" confirmation, which a notebook shell command cannot answer interactively
!pip uninstall -y numpy
!pip install numpy==1.24.0
#For storage and other dependencies
!pip install h5py

In [None]:
#Import libraries
from pyhessian import hessian
import numpy as np
from hessian_eigenthings import compute_hessian_eigenthings
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import fastai
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
#Set GPU device
torch.cuda.set_device(0)
torch.cuda.get_device_name()

In [None]:
#SPRKD class
#Import libraries
import math
#SPRKD optimizer class (wraps an inner optimizer and adds saddle-point logic)
class SPRKD(torch.optim.Optimizer):
    """
    IMPLEMENTATION OF SPRKD - SADDLE POINT RECRUITMENT FOR KNOWLEDGE DISTILLATION.

    Wraps an inner PyTorch optimizer (``optimizer``) and layers saddle-point logic on top of it.

    * Teacher mode (``is_teacher=True``): after every inner-optimizer step, approximate the top Hessian
      eigenvalues with PyHessian. When the negative-eigenvalue mass is at least 40% of the positive mass,
      append a detached CPU copy of the parameters to ``state["SADDLE_POINT_PARAMS"]``.
    * Student mode: pull the parameters towards ``teacher_saddle_points`` with an elementwise
      transformation matrix, then take a bounded number of negative-Hessian-eigensteps, then apply
      perturbed gradient descent (PGD) with loss-based reversion.

    The student logic reads ``param.grad``, so ``step`` should be called after ``loss.backward()``.
    ``model`` must expose ``.model``, ``.loss_func`` and ``.dls.train``; it is presumably a fastai ``Learner``.
    """
    #Initialize optimizer object
    def __init__(self, params, loss_function, stepsize = 0.001, bias = 0.001, generosity = 5, exploration_steps = 500, 
                is_teacher = False, teacher_saddle_points = None, optimizer = None, epsilon = 10e-3, PGD_delta = 5, 
                max_hessian_neg_eigensteps = 50, PGD_epoch_limit = 100, cooldown_steps = 20, decay = 100, selfKD = True, stride = 1):
        """Validate hyperparameters and register them as param-group defaults.

        Args:
            params: iterable of parameters (or param-group dicts) to manage.
            loss_function, stepsize, bias, generosity, exploration_steps, decay, selfKD, stride:
                validated where checked below and stored in each param group; not read by this class's methods.
            is_teacher: True -> saddle-point detection mode; False -> student mode.
            teacher_saddle_points: tensors, one per parameter and in the same order as ``params``.
                Defaults to a fresh empty list per instance.
            optimizer: inner optimizer whose ``step()`` performs the base update (stored as ``optim_function``).
            epsilon: per-row distance below which a parameter counts as inside the saddle region.
                NOTE(review): 10e-3 == 0.01; confirm that 1e-3 was not intended.
            PGD_delta: average norm gap from the saddle region that must be exceeded before PGD perturbs.
            max_hessian_neg_eigensteps: maximum number of negative-Hessian-eigensteps taken.
            PGD_epoch_limit: PGD is only applied while STEP / len(model.dls.train) is below this.
            cooldown_steps: number of steps to wait between PGD perturbation/reversion events.

        Raises:
            ValueError: if stepsize, generosity or exploration_steps is out of range, or is_teacher is not a bool.
        """
        #Check whether parameters are within bounds
        if stepsize <= 0:
            raise ValueError("Invalid stepsize '{}' provided. 'stepsize' must be in the range (0, inf).".format(stepsize))
        if (generosity <= 0 or generosity > 10) or isinstance(generosity, float):
            raise ValueError("Invalid generosity score provided. 'generosity' must be an integer in the range [1, 10].")
        if exploration_steps <= 0:
            raise ValueError("Invalid # of steps provided for exploration phase. Must be > 0.")
        if type(is_teacher) is not bool:
            raise ValueError("Invalid value for 'is_teacher' provided, expected True or False.")
        #Fresh list per instance: a shared mutable default would leak saddle points between optimizers
        if teacher_saddle_points is None:
            teacher_saddle_points = []
        #Declare DEFAULTS with provided values
        DEFAULTS = dict(loss_function = loss_function, stepsize = stepsize, bias = bias, generosity = generosity, exploration_steps = exploration_steps, 
                        is_teacher = is_teacher, teacher_saddle_points = teacher_saddle_points, optim_function = optimizer, epsilon = epsilon, PGD_delta = PGD_delta,
                        max_hessian_neg_eigensteps = max_hessian_neg_eigensteps, PGD_epoch_limit = PGD_epoch_limit, cooldown_steps = cooldown_steps, decay = decay, 
                        selfKD = selfKD, stride = stride)
        #Initialize optimizer
        super().__init__(params, DEFAULTS)

    #Method for optimization step
    def step(self, model, current_loss, n_eigs = 2, closure = None):
        """Perform one SPRKD step.

        Args:
            model: object exposing ``.model``, ``.loss_func`` and ``.dls.train``. It is used for Hessian
                estimates and for the PGD epoch count.
            current_loss: loss of the current batch. PGD uses it to judge whether a perturbation helped.
            n_eigs: number of top Hessian eigenvalues to approximate.
            closure: optional callable that re-evaluates the loss; its result is returned.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        return_loss = closure() if closure is not None else None
        #First call: create all bookkeeping entries.
        #self.state is a defaultdict(dict) in torch.optim.Optimizer, so test membership explicitly.
        if "STEP" not in self.state:
            self.state["STEP"] = 1
            #Saddle points (teacher) found so far
            self.state["SADDLE_POINT_PARAMS"] = []
            self.state["PHASE"] = "EXPLORATORY"
            self.state["GRADIENTS"] = {}
            #Cooldown period between PGD perturbations/reversions
            self.state["COOLDOWN_STEPS"] = self.param_groups[0]["cooldown_steps"]
            #Per-parameter flag: may this parameter still be pulled towards the teacher saddle region?
            self.state["ALLOW_TEACHER_SADDLE_STEPS"] = {}
            #Number of negative-Hessian-eigensteps taken so far
            self.state["N_HESSIAN_NEG_EIGENSTEPS"] = 0
            #Gradient-norm threshold below which gradients likely reflect a saddle-point plateau
            self.state["GRADIENT_THRESHOLD"] = torch.tensor(0.01)
            #Parameter values saved before each PGD perturbation (used for reversion)
            self.state["PARAM_HISTORY_PGD"] = {}
            #Loss at the last perturbation (used to judge whether it helped)
            self.state["STORED_LOSS"] = 0.0
        else:
            self.state["STEP"] += 1
        for param_group in self.param_groups:
            if param_group["is_teacher"]:
                #Teacher: plain inner-optimizer step, then check whether the new point looks like a saddle
                param_group["optim_function"].step()
                self.determineSaddlePoint(model = model, n_eigs = n_eigs, param_group = param_group, step_delta = 100)
                continue
            #Student: the inner-optimizer step is suppressed while any parameter is still being pulled
            #towards the teacher saddle region (the transformation matrix drives the update instead)
            if True in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values():
                print("\nDisabled. No ADAM step.")
            else:
                param_group["optim_function"].step()
                print(self.state["ALLOW_TEACHER_SADDLE_STEPS"])
            #Transformation Matrix, then Negative Hessian Eigensteps, then PGD
            self.applyTransformationMatrix(param_group = param_group)
            self.negativeHessianEigensteps(model = model, n_eigs = n_eigs, param_group = param_group)
            self.perturbedGD(model = model, current_loss = current_loss, param_group = param_group)
            #Decrement cooldown steps if not already zero
            if self.state["COOLDOWN_STEPS"] > 0:
                self.state["COOLDOWN_STEPS"] -= 1
        return return_loss

    #Determine if a teacher model is at a saddle point via Hessian eigenvalue directional density
    def determineSaddlePoint(self, model, n_eigs, param_group, step_delta = 100):
        """Record the current teacher parameters when the Hessian spectrum looks saddle-like.

        The Hessian is estimated with PyHessian on one training batch. If |sum of negative eigenvalues| is
        at least 40% of the sum of positive eigenvalues (among the ``n_eigs`` values PyHessian returns), a
        detached CPU copy of the group's parameters is appended to ``state["SADDLE_POINT_PARAMS"]``.
        ``step_delta`` is accepted for compatibility but not used.
        """
        hess_comp = hessian(model = model.model, criterion = model.loss_func, data = next(iter(model.dls.train)), cuda = True)
        top_eigenvalues, top_eigenvectors = hess_comp.eigenvalues(top_n = n_eigs)
        #Split eigenvalues by sign (anything neither > 0 nor == 0 counts as negative)
        pos_eigs, neg_eigs, zero_eigs = [], [], []
        for eigenvalue in top_eigenvalues: 
            if eigenvalue > 0: pos_eigs.append(eigenvalue)
            elif eigenvalue == 0: zero_eigs.append(eigenvalue) 
            else: neg_eigs.append(eigenvalue)
        print("\nNEG EIG SUM:", abs(sum(neg_eigs)))
        print("POS EIGE SUM:", sum(pos_eigs))
        print("TOP EIGENVALUES:", top_eigenvalues)
        if abs(sum(neg_eigs)) >= (0.4 * sum(pos_eigs)):
            #Clone and detach each parameter so stored saddle points do not alias the live parameters
            detached_saddle_point = [parameter.clone().detach().to('cpu') for parameter in param_group["params"]]
            self.state["SADDLE_POINT_PARAMS"].append(detached_saddle_point)
            print('APPENDED SADDLE POINTS.', end = "")

    def _averageSaddleNorm(self, param_group):
        """Return the mean of |‖param‖ - ‖saddle point‖| over the group's parameters (0.0 for an empty group).

        The sum runs over the (param, saddle point) pairs. The divisor is the number of parameters.
        """
        params = param_group["params"]
        if not params:
            return 0.0
        total_norm = 0.0
        for param, saddle_point in zip(params, param_group["teacher_saddle_points"]):
            saddle_point = saddle_point.to(param.device)
            total_norm += torch.abs(torch.linalg.norm(param.detach()) - torch.linalg.norm(saddle_point))
        return total_norm / len(params)

    #Apply Transformation Matrix to guide student within a specified epsilon-delta of the approximated saddle region
    def applyTransformationMatrix(self, param_group):
        """Pull each parameter towards its teacher saddle point until it is within ``epsilon``.

        A per-row distance to the saddle point comes from the diagonal of ``torch.cdist``. For tensors
        with more than two dimensions, the diagonal of the diagonal is used. The parameter is moved while
        both of these hold:
          * any entry of that distance exceeds ``epsilon``;
          * steps are still allowed for this parameter.
        The update is ``param <- weight * param * (saddle / param)``, i.e. ``weight * saddle``; zero entries
        are set to ``weight * saddle`` directly instead of producing NaN. The weight is
        ``1 - 2**(-STEP/10) / 2``, which rises from about 0.53 towards 1. Once a parameter is within range,
        its entry in ``state["ALLOW_TEACHER_SADDLE_STEPS"]`` becomes False and stays False.
        """
        average_norm = self._averageSaddleNorm(param_group)
        print("AVERAGE NORM FROM APPROX. SADDLE REGION:", average_norm)
        allow_steps = self.state["ALLOW_TEACHER_SADDLE_STEPS"]
        for param_number, (param, best_saddle_point) in enumerate(zip(param_group["params"], param_group["teacher_saddle_points"])):
            #Every parameter starts out allowed to seek the saddle region
            allow_steps.setdefault(param_number, True)
            current = param.data
            #Saddle points may live on the CPU (determineSaddlePoint stores them there); match device/dtype
            best_saddle_point = best_saddle_point.to(device = current.device, dtype = current.dtype)
            #Only the diagonal of the cdist matrix is checked (matching rows should have zero distance)
            if current.dim() < 2: euclidean_distance = torch.diagonal(torch.cdist(current.unsqueeze(1), best_saddle_point.unsqueeze(1)))
            elif current.dim() == 2: euclidean_distance = torch.diagonal(torch.cdist(current, best_saddle_point))
            else: euclidean_distance = torch.diagonal(torch.diagonal(torch.cdist(current, best_saddle_point)))
            #Keep stepping while any entry is farther than epsilon and steps are still allowed
            if torch.any(euclidean_distance > param_group["epsilon"]).item() and allow_steps[param_number]:
                #Elementwise transformation matrix saddle / param; zero entries get 1 to avoid inf/NaN
                nonzero = current != 0
                transform_ewise_mat = torch.where(nonzero, best_saddle_point / current, torch.ones_like(current))
                #Weight rising from ~0.53 towards 1 (the gap to 1 decays exponentially with STEP)
                weight = -2**(-self.state["STEP"]/10)/2 + 1
                #No squeeze: the result must keep the parameter's exact shape
                param.data = weight * torch.where(nonzero, current * transform_ewise_mat, best_saddle_point)
                print('\nTRANSFORMATION [{}] STEP TAKEN.'.format(param_number))
                print('\nTRANSFORMATION [{}] MATRIX NORM: {}'.format(param_number, torch.linalg.norm(transform_ewise_mat)))
                print('\nEUCLIDEAN DIST. [{}] MATRIX: {}'.format(param_number, euclidean_distance))
                print('\nEUCLIDEAN DIST. [{}] MATRIX NORM: {}'.format(param_number, torch.linalg.norm(euclidean_distance)))
            #Within epsilon (or already disabled): stop targeted steps for this parameter
            else:
                allow_steps[param_number] = False
    
    #Take Negative Hessian Eigensteps for a given step delta after the student is sufficiently close to the approximated saddle region
    def negativeHessianEigensteps(self, model, n_eigs, param_group):
        """Take one negative-Hessian-eigenstep once the student has reached the teacher saddle region.

        Does nothing in either of these cases:
          * any parameter is still allowed to seek the saddle region;
          * ``max_hessian_neg_eigensteps`` steps have already been taken.
        Otherwise it approximates the top ``n_eigs`` Hessian eigenpairs with PyHessian on one training
        batch and selects the smallest eigenvalue. If that eigenvalue is negative, each parameter is updated
        by ``param -= grad * v * v`` (elementwise, where ``v`` is that eigenvector's block for the
        parameter). The counter is incremented whenever the Hessian was evaluated, even if no step was
        applied.

        NOTE(review): the textbook escape direction uses the global dot product ``(g^T v) v`` over all
        parameters; the original elementwise product is kept here. An extra weight
        ``(1/sqrt|lambda|) * 2**(-STEP/(max_steps/1.5))`` used to be computed here but was never applied.
        """
        if True in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values() \
                or self.state["N_HESSIAN_NEG_EIGENSTEPS"] >= param_group["max_hessian_neg_eigensteps"]:
            return
        #Save gradients first: the Hessian computation below disturbs param.grad
        saved_grads = {
            param_n: None if param.grad is None else param.grad.clone().detach()
            for param_n, param in enumerate(param_group["params"])
        }
        hess_comp = hessian(model = model.model, criterion = model.loss_func, data = next(iter(model.dls.train)), cuda = True)
        top_eigenvalues, top_eigenvectors = hess_comp.eigenvalues(top_n = n_eigs)
        #Smallest eigenvalue (most negative direction) and its eigenvector
        ev_index = top_eigenvalues.index(min(top_eigenvalues))
        largest_negative_eigenvalue = float(top_eigenvalues[ev_index])
        largest_negative_eigenvector = top_eigenvectors[ev_index]
        #Only a negative eigenvalue gives an escape direction
        if largest_negative_eigenvalue < 0:
            for (param_n, param), eigenvector in zip(enumerate(param_group["params"]), largest_negative_eigenvector):
                vec_grad = saved_grads[param_n]
                #Skip parameters that had no gradient
                if vec_grad is None:
                    continue
                eigenvector = torch.as_tensor(eigenvector, device = vec_grad.device, dtype = vec_grad.dtype)
                neg_eigenstep = vec_grad.mul(eigenvector).mul(eigenvector)
                param.data = param.data - neg_eigenstep
        #Restore the pre-Hessian gradients so later logic (PGD) sees the real ones
        for param_n, param in enumerate(param_group["params"]):
            param.grad = saved_grads[param_n]
        self.state["N_HESSIAN_NEG_EIGENSTEPS"] += 1
        print("TAKEN NEGATIVE EIGENSTEP. TOP EIGENVALUES:", top_eigenvalues)
    
    #Perturbed Gradient Descent for efficient saddle point escaping
    def perturbedGD(self, model, current_loss, param_group):
        """Perturbed gradient descent with loss-based reversion.

        A parameter is eligible when all of the following hold:
          * its gradient norm is below ``GRADIENT_THRESHOLD``;
          * the average norm gap to the saddle region exceeds ``PGD_delta``;
          * STEP / len(train loader) is below ``PGD_epoch_limit``;
          * the cooldown has elapsed;
          * no parameter is still seeking the saddle region.
        For an eligible parameter:
          * on the first perturbation, or when the loss dropped by >= 0.002 since the last one, save the
            parameter, add Gaussian noise with variance 0.1 and remember ``current_loss``;
          * otherwise revert the parameter to the value saved before its last perturbation.
        Each such event resets the cooldown counter.

        NOTE(review): the cooldown is reset inside the per-parameter loop, so at most one parameter is
        perturbed or reverted per cooldown window.
        """
        #Average norm gap between the current parameters and the approx. saddle region
        avg_norm = self._averageSaddleNorm(param_group)
        history = self.state["PARAM_HISTORY_PGD"]
        for param_idx, param in enumerate(param_group["params"]):
            #A parameter without a gradient cannot be compared against the gradient threshold
            if param.grad is None:
                continue
            if torch.linalg.norm(param.grad) < self.state["GRADIENT_THRESHOLD"] and avg_norm > param_group["PGD_delta"] \
                and self.state["STEP"]/len(model.dls.train) < param_group["PGD_epoch_limit"]:
                #Only after the cooldown, and only once the saddle region has been reached
                if self.state["COOLDOWN_STEPS"] == 0 and True not in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values():
                    #First perturbation or sufficient loss reduction: save, perturb, record loss
                    if self.state["STORED_LOSS"] == 0.0 or self.state["STORED_LOSS"] - current_loss >= 0.002: 
                        history[param_idx] = param.detach().clone()
                        #Gaussian noise with variance 0.1 (std sqrt(0.1))
                        param.data = param.data + math.sqrt(0.1) * torch.randn_like(param.data)
                        self.state["STORED_LOSS"] = current_loss
                        print("Perturbed to avoid saddle point stagnation.")
                    #Perturbation did not help: revert to the saved pre-perturbation value
                    elif self.state["STORED_LOSS"] - current_loss < 0.002 and param_idx in history:
                        param.data = history[param_idx].clone()
                        print("Perturbation ineffective [Reduction: {}]. Reverted to non-perturbed point.".format(self.state["STORED_LOSS"] - current_loss))
                    self.state["COOLDOWN_STEPS"] = param_group["cooldown_steps"]
