### **SPRKD (SADDLE POINT RECRUITMENT FOR KNOWLEDGE DISTILLATION): HIGHER-ORDER EXPERIMENTATION AND ANALYSIS.**

In [1]:
#Install necesssary libraries
#Hessian eigenthings and PyHessian
!pip install --upgrade "git+https://github.com/thetechdude124/pytorch-hessian-eigenthings.git@master#egg=hessian-eigenthings"
#Install PyHessian library (personal fork with version-specific updates)
!pip install --upgrade "git+https://github.com/thetechdude124/pyhessian.git@master#egg=pyhessian"
#PIL for image processing
!pip install pillow-simd
#Change numpy to version 1.24.0 (needed for TLI saddle injection)
!pip uninstall --yes numpy
!pip install --yes numpy==1.24.0
#For storage and other dependencies
!pip install h5py

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting hessian-eigenthings
  Cloning https://github.com/thetechdude124/pytorch-hessian-eigenthings.git (to revision master) to /tmp/pip-install-8brtsz9y/hessian-eigenthings_d4d4bb954eba4e8fba4a2e76966ffe93
  Running command git clone --filter=blob:none --quiet https://github.com/thetechdude124/pytorch-hessian-eigenthings.git /tmp/pip-install-8brtsz9y/hessian-eigenthings_d4d4bb954eba4e8fba4a2e76966ffe93
  Resolved https://github.com/thetechdude124/pytorch-hessian-eigenthings.git to commit bb2596e7db127fdfafee7d05a8056f7bb046882f
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting numpy>=0.14
  Using cached numpy-1.22.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.9 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.24.2
    Uninstalling numpy-1.24.2:
      Successfully uninstalled numpy-1.24.2
Succ

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyhessian
  Cloning https://github.com/thetechdude124/pyhessian.git (to revision master) to /tmp/pip-install-p75cjwb6/pyhessian_e27514d28f334eefaf4e30099b0e3cf1
  Running command git clone --filter=blob:none --quiet https://github.com/thetechdude124/pyhessian.git /tmp/pip-install-p75cjwb6/pyhessian_e27514d28f334eefaf4e30099b0e3cf1
  Resolved https://github.com/thetechdude124/pyhessian.git to commit a6b86168748011858c5c0701f49121a10f78098a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Found existing installation: numpy 1.22.4
Uninstalling numpy-1.22.4:
  Successfully uninstalled numpy-1.22.4

Usage:   
  pip3 install [options] <requirement specifier> [package-index-options] ...
  pip3 install [options] -r <requirements file> [package-index-options] ...
  pip3 install [option

In [2]:
#Import libraries
from pyhessian import hessian
import numpy as np
from hessian_eigenthings import compute_hessian_eigenthings
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import fastai
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
#Set GPU device
torch.cuda.set_device(0)
torch.cuda.get_device_name()



'Tesla T4'

In [4]:
#Malaria dataset loader
#Define transforms (resize images to 32 x 32 with 3 channels)
reshape_size = 32
transform = transforms.Compose([transforms.Resize((reshape_size, reshape_size)),
                                transforms.ToTensor()
                                #transforms.RandomRotation(20),
                                #transforms.RandomHorizontalFlip(),
                                #transforms.RandomAffine(0.05)
                                ])
malaria_dataset = datasets.ImageFolder(r'./cell_images/', transform = transform)
#Split dataset into training and validation
train_len = int(round(len(malaria_dataset) * 0.75, 0))
valid_len = int(round(len(malaria_dataset) * 0.25, 0))
train_set, valid_set = torch.utils.data.random_split(malaria_dataset, [train_len, valid_len])
#Create train and validation dataloaders (use GPU via pin_memory)
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size, shuffle = True, pin_memory = True, num_workers = 2)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size = batch_size, shuffle = True, pin_memory = True, num_workers = 2)
#Shift dataloaders to fastai for use with learner objects
MALARIA_TRAIN_DATALOADER = DataLoaders(train_loader, valid_loader)
#print(next(iter(valid_loader))[1])

In [8]:
#SPRKD class
#Import libraries
import math
#Memory optimizer class
class SPRKD(torch.optim.Optimizer):
    """
    IMPLEMENTATION OF SPRKD - SADDLE POINT RECRUITMENT FOR KNOWLEDGE DISTILLATION.
    
    """
    #Initialize optimizer object
    def __init__(self, params, loss_function, stepsize = 0.001, bias = 0.001, generosity = 5, saddle_steps = 500, 
                is_teacher = False, is_control = False, teacher_saddle_points = [], optimizer = None, epsilon = 10e-3, PGD_delta = 5, 
                PGD_epoch_limit = 100, PGD_grad_threshold = 0.01, max_hessian_neg_eigensteps = 50, cooldown_steps = 20, decay = 100, selfKD = True, stride = 1):
        #Check whether parameters are within bounds
        if stepsize <= 0:
            raise ValueError("Invalid stepsize '{}' provided. 'stepsize' must be in the range (0, inf).")
        if (generosity <= 0 or generosity > 10) or isinstance(generosity, float):
            raise ValueError("Invalid generosity score provided. 'generosity' must be an integer in the range [1, 10].")
        if saddle_steps <= 0:
            raise ValueError("Invalid # of steps provided for saddle point checking. Must be >= 0.")
        if type(is_teacher) is not bool:
            raise ValueError("Invalid value for 'is_teacher' provided, expected True or False.")
        #Declare DEFAULTS with provided values
        DEFAULTS = dict(loss_function = loss_function, stepsize = stepsize, bias = bias, generosity = generosity, saddle_steps = saddle_steps, 
                        is_teacher = is_teacher, is_control = is_control, teacher_saddle_points = teacher_saddle_points, optim_function = optimizer, epsilon = epsilon, 
                        PGD_delta = PGD_delta, PGD_epoch_limit = PGD_epoch_limit, PGD_grad_threshold = PGD_grad_threshold, max_hessian_neg_eigensteps = max_hessian_neg_eigensteps,
                        cooldown_steps = cooldown_steps, decay = decay, selfKD = selfKD, stride = stride)
        #Initialize optimizer
        super(SPRKD, self).__init__(params, DEFAULTS)

    #Method for optimization step
    def step(self, model, current_loss, n_eigs = 2, closure = None):
        #Set loss to None
        return_loss = None
        #Loss must become closure function if it is defined, else remain the same
        return_loss = closure() if closure != None else return_loss
        #Check if we are on the first iteration of the optimizer - if so, setup variables 
        if not self.state["STEP"]:
                #Set step to 1
                self.state["STEP"] = 1
                #Create list to hold saddle points and model parameters at those points
                self.state["SADDLE_POINT_PARAMS"] = []
                self.state["PHASE"] = "EXPLORATORY"
                self.state["GRADIENTS"] = {}
                #Define cooldown period for saddle point reversions
                self.state["COOLDOWN_STEPS"] = self.param_groups[0]["cooldown_steps"]
                #Create dictionary of booleans to track whether each parameter matrix should be allowed to converge to the teacher saddle point region 
                self.state["ALLOW_TEACHER_SADDLE_STEPS"] = {}
                #Create counter to track the number of Hessian negative eigensteps taken
                self.state["N_HESSIAN_NEG_EIGENSTEPS"] = 0
                #Add history of parameter values for Peterubed Gradient Descent Mechanics
                self.state["PARAM_HISTORY_PGD"] = {}
                #Save loss for same mechanics (needed to evaluate whether the Perterbation was successful or if reversion is necessary)
                self.state["STORED_LOSS"] = 0.0
        #If we are not on the first iteration, increment the current step and set loss 
        else: 
            self.state["STEP"] += 1
        #Print current step
        print('STEP ' + str(self.state["STEP"]), end = "\n")
        #Iterate over parameter groups
        for param_group in self.param_groups:
            #If the model is control, simply take a step
            if param_group["is_control"]: 
                param_group["optim_function"].step()
                continue
            #Check to see if any parameters are outside the epsillon delta range of the saddle point region
            #Do so only if the current model is not a teacher
            if param_group["is_teacher"]: param_group["optim_function"].step()
            #The model is otherwise a student - take a step if approximated saddle region reached
            elif True in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values(): print("\nDisabled. No ADAM step.")
                # print(self.state["ALLOW_TEACHER_SADDLE_STEPS"])
            #If True is not in the dictionary, it must be set to false -> take a step
            else: param_group["optim_function"].step(), print(self.state["ALLOW_TEACHER_SADDLE_STEPS"])
            #Check if the model being optimized is a student or teacher
            #If teacher, detect saddle points via Hessian eigenvalues approximation and density
            #Check for saddle points every saddle_steps (user_defined)
            if param_group["is_teacher"] and self.state["STEP"] % param_group["saddle_steps"] == 0: 
                print("DETERMINING SADDLE POINT PRESENCE.")
                self.determineSaddlePoint(model = model, n_eigs = n_eigs, param_group = param_group)
            #If student, apply Transformation Matrix, Negative Hessian Eigensteps, and/or PGD steps
            elif not param_group["is_teacher"]:
                #Calculate average norm between approx. saddle region and current parameters
                #Needed for both TM application and PGD
                total_norm = 0.0
                for param, saddle_point in zip(param_group["params"], param_group["teacher_saddle_points"]):
                    norm_diff = torch.abs(torch.linalg.norm(param) - torch.linalg.norm(saddle_point))
                    total_norm += norm_diff
                #Find average norm
                average_norm = total_norm/len(param_group["params"])
                print("AVERAGE NORM FROM APPROX. SADDLE REGION:", average_norm)
                #Determine if TM application is needed
                self.applyTransformationMatrix(average_norm = average_norm, param_group = param_group)
                #Conditionally apply Negative Hessian Eigensteps
                self.negativeHessianEigensteps(model = model, n_eigs = n_eigs, param_group = param_group)
                #Determine PGD necessity and apply accordingly
                self.perturbedGD(model = model, average_norm = average_norm, current_loss = current_loss, param_group = param_group)
                #Decrement cooldown steps if not already zero
                if self.state["COOLDOWN_STEPS"] > 0: self.state["COOLDOWN_STEPS"] -= 1
        return return_loss

    #Determine if a teacher model is at a saddle point via Hessian eigenvalue directional density
    def determineSaddlePoint(self, model, n_eigs, param_group):
        #Approximate the eigenvalues of the Hessian matrix and find determinant + trace.
        #Create Hessian computation object
        hess_comp = hessian(model = model.model, criterion = model.loss_func, data = next(iter(model.dls.train)), cuda = True)
        #Find top two eigenvalues/eigenvectors
        top_eigenvalues, top_eigenvectors = hess_comp.eigenvalues(top_n = n_eigs)
        #Get # of positive and negative eigenvalues
        pos_eigs, neg_eigs, zero_eigs = [], [], []
        for eigenvalue in top_eigenvalues: 
            if eigenvalue > 0: pos_eigs.append(eigenvalue)
            elif eigenvalue == 0: zero_eigs.append(eigenvalue) 
            else: neg_eigs.append(eigenvalue)
        #Compare total magnitude of positive and negative directions - if the magnitude of negative directions is at least 40% of positive direction magntidue, append saddle point
        print("NEG EIG SUM:", abs(sum(neg_eigs)))
        print("POS EIG SUM:", sum(pos_eigs))
        print("TOP EIGENVALUES:", top_eigenvalues)
        if abs(sum(neg_eigs)) >= (0.4 * sum(pos_eigs)):
            #Clone and detach parameters before appending (otherwise, all parameter matrices become linked together and thereby become identical)
            #Iterate over each parameter, detach, and store in a new list object
            detached_saddle_point = [parameter.clone().detach().to('cpu') for parameter in param_group["params"]]
            self.state["SADDLE_POINT_PARAMS"].append(detached_saddle_point), print("APPENDED SADDLE POINTS.")

    #Apply Transformation Matrix to guide student within a specified epsilon-delta of the approximated saddle region
    def applyTransformationMatrix(self, average_norm, param_group):
        #Iterate over all parameters, determine if TM is needed, apply accordingly
        for (param_number, param), (saddle_number, best_saddle_point) in zip(enumerate(param_group["params"]), enumerate(param_group["teacher_saddle_points"])):
            #If this is the first step, initialize all saddle point seeking behaviours for the parameter to True
            if self.state["STEP"] == 1: 
                self.state["ALLOW_TEACHER_SADDLE_STEPS"][param_number] = True
            #Find Euclidean Distance Matrix between current parameters and saddle point
            #If matrices are one dimensional, unsqueeze 
            #Only check DIAGONAL entries as opposed to all (diagonal entries should be identical if matrices are identical)
            if len(param.shape) < 2: euclidean_distance = torch.diagonal(torch.cdist(param.unsqueeze(1), best_saddle_point.unsqueeze(1)))
            #If the matrix is exactly two-dimensional, simply take the diagonal without unsqueezing
            elif len(param.shape) == 2: euclidean_distance = torch.diagonal(torch.cdist(param, best_saddle_point))
            #For matrices higher than 2 dimensions
            #Take the diagonal of the diagonal matrix as this is a multidimensional tensor (the desired values are diag^2)
            else: euclidean_distance = torch.diagonal(torch.diagonal(torch.cdist(param, best_saddle_point)))
            #Check if all elements are below epsilon - first label the matrix by True or False depending on if they meet this condition
            #False - below epsilon. True - above epsilon.
            euclidean_distance_bool = euclidean_distance > param_group["epsilon"]
            #If any 'True' elements are present, continue taking steps towards the teacher saddle point region
            if torch.any(euclidean_distance_bool).item() and self.state["ALLOW_TEACHER_SADDLE_STEPS"][param_number]: 
                #Build elementwise transformation matrix between current parameters and saddle point parameters
                #Divide matrices
                transform_ewise_mat = torch.div(best_saddle_point, param)
                #Multiply by current parameter matrix 
                #Add exponentially decaying weight to prevent long-term convergence hindering
                weight = -2**(-self.state["STEP"]/10)/2 + 1
                #Apply TM * weight to transformation matrix
                param.data = param.squeeze().mul(weight * transform_ewise_mat)
                print('\nTRANSFORMATION [{}] STEP TAKEN.'.format(param_number))
                print('\nTRANSFORMATION [{}] MATRIX NORM: {}'.format(param_number, torch.linalg.norm(transform_ewise_mat)))
                print('\nEUCLIDEAN DIST. [{}] MATRIX: {}'.format(param_number, euclidean_distance))
                print('\nEUCLIDEAN DIST. [{}] MATRIX NORM: {}'.format(param_number, torch.linalg.norm(euclidean_distance)))
            #Otherwise, if this parameter is in epsilon range of the saddle point, disable taking targeted steps to this region for said parameter
            else: self.state["ALLOW_TEACHER_SADDLE_STEPS"][param_number] = False
    
    #Take Negative Hessian Eigensteps for a given step delta after the student is sufficiently close to the approximated saddle point region
    def negativeHessianEigensteps(self, model, n_eigs, param_group):
        #Check if all parameters are sufficiently close to the approx. saddle region and if within the step delta
        #End if not true
        if True in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values() or self.state["N_HESSIAN_NEG_EIGENSTEPS"] > param_group["max_hessian_neg_eigensteps"]: return
        #Otherwise, begin taking negative Hessian eigenstep
        #Iterate and save gradients (will otherwise be distrupted by Hessian eigenvalue computations)
        saved_grads = {}
        for param_n, param in enumerate(param_group["params"]): saved_grads[param_n] = param.grad.clone().detach()
        #Calculate top 2 Hessian eigenvalues and eigenvectors
        hess_comp = hessian(model = model.model, criterion = model.loss_func, data = next(iter(model.dls.train)), cuda = True)
        top_eigenvalues, top_eigenvectors = hess_comp.eigenvalues(top_n = n_eigs)
        #Find position of smallest eigenvalue (largest negative eigenvalue) and the associated eigenvector
        ev_index = top_eigenvalues.index(min(top_eigenvalues))
        largest_negative_eigenvalue = torch.tensor(top_eigenvalues[ev_index])
        largest_negative_eigenvector = top_eigenvectors[ev_index]
        #Iterate over each parameter and multiply (rescale) the gradients by the eigenvector corresponding to the largest negative direction (smallest eigenvalue)
        for (param_n, param), (eigenvector_n, eigenvector)in zip(enumerate(param_group["params"]), enumerate(largest_negative_eigenvector)):
            #Convert to PyTorch tensor
            eigenvector = torch.tensor(eigenvector)
            #Break if the eigenvalue is negative or 0
            if largest_negative_eigenvalue >= 0: break
            #Skip if the parameter has no gradient
            if saved_grads[param_n] == None: continue
            #Otherwise, make a copy of the gradient and reshape it to a vector
            vec_grad = saved_grads[param_n]
            #Use the formula p_t - [g(x)^T*v]*v to transform the parameters along the respective eigendirection (where v is the eigenvector)
            #Compute negative Hessian eigenstep
            neg_eigenstep = vec_grad.mul(eigenvector).mul(eigenvector)
            #Reshape the eigenstep vector back into the size of the original matrix
            #neg_eigenstep = torch.reshape(neg_eigenstep, param.grad.shape)
            #Update the parameters by the learning rate * the negative eigenstep
            #Compute weight
            weight = 2**(-(1/(param_group["max_hessian_neg_eigensteps"]/1.5)) * self.state["STEP"])
            param.data = param.data - (0.1 * weight * neg_eigenstep)
        #Increase counter of Hessian eigensteps taken.
        self.state["N_HESSIAN_NEG_EIGENSTEPS"] += 1
        print("TAKEN NEGATIVE EIGENSTEP. TOP EIGENVALUES:", top_eigenvalues)
    
    #Perturbed Gradient Descent/Negative Hessian Eigensteps for efficient saddle point escaping
    def perturbedGD(self, model, average_norm, current_loss, param_group):
        #As per Perturbated Gradient Descent (PGD), add a small value to the current parameters if this is the case
        for param_idx, param in enumerate(param_group["params"]):
            #Compare norm to gradient threshold
            #Also check if the average norm between the aproximated saddle region and the model's current location in parameter space exceeds the PGD delta
            #Determine whether SPRKD is within the PGD epoch limit
            if torch.linalg.norm(param.grad) < param_group["PGD_grad_threshold"] and average_norm > param_group["PGD_delta"] \
                and self.state["STEP"]/len(model.dls.train) < param_group["PGD_epoch_limit"]:
                #If larger, add the perturbation
                #Do so only if the cooldown steps have been fulfilled (at least 50)
                #And, check that the saddle point region has already been reached (avoiding interference with transformation matrix computations)
                if self.state["COOLDOWN_STEPS"] == 0 and True not in self.state["ALLOW_TEACHER_SADDLE_STEPS"].values():
                    #If this is the first iteration or the loss trheshold has been met, simply perturb, save loss, and save parameters 
                    if self.state["STORED_LOSS"] == 0.0 or self.state["STORED_LOSS"] - current_loss >= 0.002: 
                        #Save current parameter to parameter history
                        self.state["PARAM_HISTORY_PGD"][param_idx] = param.clone().detach()
                        #Perturbate parameters - add Gausian noise with variance 0.1
                        param.data = param + ((math.sqrt(0.1)) * torch.rand_like(param))
                        #Set loss
                        self.state["STORED_LOSS"] = current_loss
                        print("Perturbed to avoid saddle point stagnation.")
                    #If the condition has not been met, revert to the point before the most recent perturbation
                    elif self.state["STORED_LOSS"] - current_loss < 0.002:
                        param.data = self.state["PARAM_HISTORY_PGD"][param_idx] = param.clone().detach()
                        print("Perturbation ineffective [Reduction: {}]. Reverted to non-perturbed point.".format(self.state["STORED_LOSS"] - current_loss))
                    self.state["COOLDOWN_STEPS"] = param_group["cooldown_steps"]


In [16]:
#Define template architecture for all teacher and student models (performing self-distillation)
def defineModels():
    #Define teacher model
    globals()["MALARIA_CNN_ARCHITECTURE"] = nn.Sequential(
        nn.Conv2d(3, 4, kernel_size = (3, 3)),
        nn.ReLU(),
        nn.Conv2d(4, 8, kernel_size = (3, 3)),
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
        nn.Dropout(0.1),
        nn.Flatten(),
        nn.Linear(1568, 16),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(16, 2),
        nn.Softmax(dim = 1)
    )
    #Define student model
    globals()["MALARIA_CNN_KD_ARCHITECTURE_WR"] = nn.Sequential(
        nn.Conv2d(3, 2, kernel_size = (3, 3)),
        nn.ReLU(),
        nn.Conv2d(2, 4, kernel_size = (3, 3)),
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
        nn.Dropout(0.1),
        nn.Flatten(),
        nn.Linear(784, 8),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(8, 2),
        nn.Softmax(dim = 1)
    )
#Define models
defineModels()
#Define train teachers functionality
#Train one teacher at a time - multiple simultaneously extends maximum possible runtime with high chance of losing existing tracked saddle points
def train_teachers(n_teacher, n_epochs, n_samples_per_epoch, architecture, experiment, loss_function, saddle_steps, saddle_points):
    #Check to see what dataset is being used
    dataloader = globals()[experiment + "_TRAIN_DATALOADER"]
    #Create learner object
    teacher_learner = globals()["TEACHER_" + str(n_teacher + 1) + "_" + experiment] = Learner(dataloader, deepcopy(architecture), metrics = "accuracy", loss_func = loss_function)
    teacher_learner.model = teacher_learner.model.to('cuda')
    #Find optimal learning rate
    #lr = globals()["TEACHER_" + str(teacher) + "_" + experiment + "_optimalLR"] = teacher_learner.lr_find()[0]
    lr = 0.001
    #Set loss function
    teacher_learner.loss_func = loss_function
    #Set Adam optimizer to take the actual steps
    ADAM_opt = Adam(teacher_learner.parameters(), lr = lr)
    #Set SPRKD for saddle point tracking and reversions
    optimizer = globals()["TEACHER_" + str(n_teacher + 1) + "_" + experiment].opt = SPRKD(params = teacher_learner.parameters(), loss_function = loss_function, 
                                                                                          stepsize = lr, bias = 0.001, generosity = 5, saddle_steps = saddle_steps, 
                                                                                          is_teacher = True, optimizer = ADAM_opt)
    #Declare loss and accuracy arrays for training and validation
    train_losses = globals()["TEACHER_" + str(n_teacher) + "_" + experiment + "_TRAIN_LOSSES"] = []
    train_accuracies = globals()["TEACHER_" + str(n_teacher) + "_" + experiment + "_TRAIN_ACCURACIES"] = []
    valid_losses = globals()["TEACHER_" + str(n_teacher) + "_" + experiment + "_VALID_LOSSES"] = []
    valid_accuracies = globals()["TEACHER_" + str(n_teacher) + "_" + experiment + "_VALID_ACCURACIES"] = []
    #Print current teacher
    print('========== TEACHER [{}] TRAINING =========='.format(n_teacher + 1))
    #Iterate over epochs
    for epoch in range(n_epochs):
        #Training loop
        perform_model_train_loop(learner = teacher_learner, n_samples_per_epoch = n_samples_per_epoch, 
                            losses = train_losses, accuracies = train_accuracies, epoch = epoch, epochs = n_epochs)
        #Validation loop
        perform_model_valid_loop(learner = teacher_learner, n_samples_per_epoch = None, 
                            losses = valid_losses, accuracies = valid_accuracies, epoch = epoch, epochs = n_epochs)
    #Get key saddle points from teacher and store (discard the 2D array and "params" key, only the end list of parameters is needed)
    saddle_points[n_teacher] = optimizer.state_dict()["state"]["SADDLE_POINT_PARAMS"]
    #Return training complete message
    return "TRAINING COMPLETE."

#Define model training functionality (needed as we are using a custom optimizer)
def perform_model_train_loop(learner, n_samples_per_epoch, losses, accuracies, epoch, epochs):
    #Loss and accuracy metrics
    epoch_loss = 0.0
    epoch_accuracies = 0.0
    #Iterate over dataloader
    for batch_index, batch_data in enumerate(learner.dls.train, 0):
        #Get inputs and labels
        inputs, labels = batch_data
        #Move to GPU
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        #Set gradients to zero
        learner.zero_grad()
        #Predict based on inputs (find argmax labels)
        preds = learner.model(inputs)
        predicted_labels = torch.argmax(preds, dim = 1)
        #Find losimply bess from true labels - convert both inputs and labels to tensors first
        tensor_labels = torch.tensor(labels).clone().detach()
        loss = learner.loss_func(preds, tensor_labels)
        #Compute gradients
        loss.backward()
        #Take step with optimizer
        learner.opt.step(model = learner, n_eigs = 4, current_loss = loss.clone().detach())
        #Add loss
        epoch_loss += loss.item()
        #Squeeze predictions and true labels in preperation for accuracy computation
        predicted_labels.squeeze_()
        labels.squeeze_()
        #Compute accuracy - convert to boolean mask, sum, and then into scalar to be averaged and processed
        acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size
        #Add to total accuracies
        epoch_accuracies += acc
        #print("LOSS: {} ACCURACY: {}".format(loss, acc))
        #Check if we are on the last sample - break the loop and print the results for this epoch.
        #Append data to loss and accuracy arrays first
        losses.append(loss)
        accuracies.append(acc)
        #Remove inputs and labels from GPU
        del inputs
        del labels
        del predicted_labels
        del acc
        del preds
        torch.cuda.empty_cache()
        if batch_index == n_samples_per_epoch - 1:
            break
    #Compute average loss
    avg_loss = epoch_loss/len(learner.dls.train)
    #Same with accuracy
    avg_accuracy = epoch_accuracies/len(learner.dls.train)
    #Print loss
    print('EPOCH [{}/{}] (TRAINING) - LOSS: {} ACCURACY: {} '.format(epoch + 1, epochs, avg_loss, avg_accuracy))

#Function for validation loop
def perform_model_valid_loop(learner, n_samples_per_epoch, losses, accuracies, epoch, epochs):
    #Loss and accuracy metrics
    epoch_loss = 0.0
    epoch_accuracies = 0.0
    #If no samples per epoch are specified, set to the length of the dataloader
    if n_samples_per_epoch == None: n_samples_per_epoch = len(learner.dls.valid)
    #Iterate over dataloader
    for batch_index, batch_data in enumerate(learner.dls.valid, 0):
        with torch.no_grad():
            #Get inputs and labels
            inputs, labels = batch_data
            #Move to GPU
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            #Set gradients to zero
            learner.zero_grad()
            #Predict based on inputs (find argmax labels)
            preds = learner.model(inputs)
            predicted_labels = torch.argmax(preds, dim = 1)
            #Find loss from true labels - convert both inputs and labels to tensors first
            tensor_labels = torch.tensor(labels).clone().detach()
            loss = learner.loss_func(preds, tensor_labels)
            #Add loss
            epoch_loss += loss.item()
            #Squeeze predictions and true labels in preperation for accuracy computation
            predicted_labels.squeeze_()
            labels.squeeze_()
            #Compute accuracy - convert to boolean mask, sum, and then into scalar to be averaged and processed
            acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size
            #Add to total accuracies
            epoch_accuracies += acc
            #print("LOSS: {} ACCURACY: {}".format(loss, acc))
            #Check if we are on the last sample - break the loop and print the results for this epoch.
            #Append data to loss and accuracy arrays first
            losses.append(loss)
            accuracies.append(acc)
            #Remove inputs and labels, accuracy, and predictions from GPU (combating potential memory leaks)
            del inputs
            del labels
            del predicted_labels
            del acc
            del preds
            torch.cuda.empty_cache()
        if batch_index == n_samples_per_epoch - 1:
            break
    #Compute average loss
    avg_loss = epoch_loss/n_samples_per_epoch
    #Same with accuracy
    avg_accuracy = epoch_accuracies/n_samples_per_epoch
    #Print loss
    print('EPOCH [{}/{}] (VALIDATION) - LOSS: {} ACCURACY: {} '.format(epoch + 1, epochs, avg_loss, avg_accuracy))


### **SPRKD TEACHER ENSEMBLE TRAINING.**

In [18]:
#Train ensemble of teacher models (3)
#Define loss function (cross entropy)
CE_Loss = nn.CrossEntropyLoss()  
#Define number of epochs and number of samples per epoch
n_epochs = 2
n_samples_per_epoch = 700
#Define how often saddle point checking should take place
saddle_steps = 1
#Initialize saddle point dictionary
SADDLE_POINTS = {}
#Train first teacher model
train_teachers(n_teacher = 0, n_epochs = n_epochs, n_samples_per_epoch = n_samples_per_epoch,
               architecture = MALARIA_CNN_ARCHITECTURE, experiment = "MALARIA", loss_function = CE_Loss, 
               saddle_steps = saddle_steps, saddle_points = SADDLE_POINTS)

STEP 1
DETERMINING SADDLE POINT PRESENCE.


  tensor_labels = torch.tensor(labels).clone().detach()


NEG EIG SUM: 1.1199682354927063
POS EIG SUM: 2.26741024851799
TOP EIGENVALUES: [1.8279528617858887, -0.6201872229576111, 0.43945738673210144, -0.4997810125350952]
APPENDED SADDLE POINTS.
STEP 2
DETERMINING SADDLE POINT PRESENCE.


  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


NEG EIG SUM: 0.8903545439243317
POS EIG SUM: 1.9828246533870697
TOP EIGENVALUES: [1.5607473850250244, -0.4933721423149109, 0.4220772683620453, -0.3969824016094208]
APPENDED SADDLE POINTS.
STEP 3
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.31897637993097305
POS EIG SUM: 1.0558243989944458
TOP EIGENVALUES: [1.0558243989944458, -0.052079878747463226, -0.1319534033536911, -0.13494309782981873]
STEP 4
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.9789082705974579
POS EIG SUM: 1.6453723907470703
TOP EIGENVALUES: [1.1312706470489502, -0.45272132754325867, 0.5141017436981201, -0.5261869430541992]
APPENDED SADDLE POINTS.
STEP 5
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.35481879115104675
POS EIG SUM: 1.7685806155204773
TOP EIGENVALUES: [1.0380113124847412, 0.39525526762008667, -0.35481879115104675, 0.3353140354156494]
STEP 6
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.4908277216600254
POS EIG SUM: 1.9933186769485474
TOP EIGENVALUES: [1.9933186769485474, -0.463016837835311

  tensor_labels = torch.tensor(labels).clone().detach()
  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


EPOCH [1/2] (VALIDATION) - LOSS: 0.647927439875073 ACCURACY: 65.37905092592592 
STEP 324
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 2.882235527038574
POS EIG SUM: 220.19106554985046
TOP EIGENVALUES: [183.11793518066406, 34.26161575317383, 2.8115146160125732, -2.882235527038574]
STEP 325
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 2.4515175819396973
POS EIG SUM: 268.1273498535156
TOP EIGENVALUES: [231.7049560546875, 33.931915283203125, 2.490478515625, -2.4515175819396973]
STEP 326
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0
POS EIG SUM: 243.14390969276428
TOP EIGENVALUES: [193.9462432861328, 40.34708023071289, 3.1337945461273193, 5.71679162979126]
STEP 327
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 4.3827104568481445
POS EIG SUM: 218.14737701416016
TOP EIGENVALUES: [194.5545654296875, 19.124855041503906, 4.46795654296875, -4.3827104568481445]
STEP 328
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 4.926212310791016
POS EIG SUM: 242.05219554901123
TOP EIGENVALUES: [2

'TRAINING COMPLETE.'

In [19]:
#Save checkpoint
torch.save(TEACHER_1_MALARIA, "TRUE_TEACHER_1_MALARIA.pth")

In [20]:
#Train the second teacher model
train_teachers(n_teacher = 1, n_epochs = n_epochs, n_samples_per_epoch = n_samples_per_epoch,
               architecture = MALARIA_CNN_ARCHITECTURE, experiment = "MALARIA", loss_function = CE_Loss, 
               saddle_steps = saddle_steps, saddle_points = SADDLE_POINTS)

STEP 1
DETERMINING SADDLE POINT PRESENCE.


  tensor_labels = torch.tensor(labels).clone().detach()


NEG EIG SUM: 0.44139184057712555
POS EIG SUM: 1.5440639853477478
TOP EIGENVALUES: [1.2688438892364502, -0.29639098048210144, 0.2752200961112976, -0.1450008600950241]
STEP 2
DETERMINING SADDLE POINT PRESENCE.


  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


NEG EIG SUM: 0.8997754454612732
POS EIG SUM: 1.88761305809021
TOP EIGENVALUES: [1.3754276037216187, 0.5121854543685913, -0.435833603143692, -0.4639418423175812]
APPENDED SADDLE POINTS.
STEP 3
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 1.3371942639350891
POS EIG SUM: 3.2945693731307983
TOP EIGENVALUES: [2.7791192531585693, -0.7451940774917603, 0.515450119972229, -0.5920001864433289]
APPENDED SADDLE POINTS.
STEP 4
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.14354152977466583
POS EIG SUM: 2.1859364807605743
TOP EIGENVALUES: [1.909410834312439, 0.13999006152153015, -0.14354152977466583, 0.13653558492660522]
STEP 5
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.21151798963546753
POS EIG SUM: 2.195126235485077
TOP EIGENVALUES: [1.8631051778793335, -0.21151798963546753, 0.1994229406118393, 0.1325981169939041]
STEP 6
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 1.7069233059883118
POS EIG SUM: 2.6501673460006714
TOP EIGENVALUES: [1.6315293312072754, 1.018638014793396, -0.900511

Traceback (most recent call last):


NEG EIG SUM: 1.6395666003227234
POS EIG SUM: 3.9840882420539856
TOP EIGENVALUES: [3.3469717502593994, -0.9165865182876587, -0.7229800820350647, 0.6371164917945862]
APPENDED SADDLE POINTS.
STEP 20
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0
POS EIG SUM: 2.6394958794116974
TOP EIGENVALUES: [1.8935425281524658, 0.20900648832321167, 0.2947607934474945, 0.2421860694885254]
STEP 21
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.2389225959777832
POS EIG SUM: 2.342095948755741
TOP EIGENVALUES: [2.219322681427002, -0.1245252713561058, 0.12277326732873917, -0.1143973246216774]
STEP 22
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.3493492901325226
POS EIG SUM: 2.625719293951988
TOP EIGENVALUES: [2.0308942794799805, -0.3493492901325226, 0.3504764139652252, 0.24434860050678253]
STEP 23
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.7661405205726624
POS EIG SUM: 3.485051929950714
TOP EIGENVALUES: [1.751258373260498, 1.0190974473953247, -0.7661405205726624, 0.7146961092948914]
STEP 24

  tensor_labels = torch.tensor(labels).clone().detach()
  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


EPOCH [1/2] (VALIDATION) - LOSS: 0.6481676692212069 ACCURACY: 65.2488425925926 
STEP 324
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 7.032716274261475
POS EIG SUM: 423.18100357055664
TOP EIGENVALUES: [373.58758544921875, 42.38432693481445, 7.2090911865234375, -7.032716274261475]
STEP 325
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 7.703215599060059
POS EIG SUM: 420.218035697937
TOP EIGENVALUES: [358.8130798339844, 53.63907241821289, -7.703215599060059, 7.765883445739746]
STEP 326
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 6.6386613845825195
POS EIG SUM: 366.0487151145935
TOP EIGENVALUES: [315.8748779296875, 43.29475021362305, 6.879086971282959, -6.6386613845825195]
STEP 327
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 3.229785680770874
POS EIG SUM: 208.75269556045532
TOP EIGENVALUES: [173.5936279296875, 31.012033462524414, 4.147034168243408, -3.229785680770874]
STEP 328
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 1.2087864875793457
POS EIG SUM: 285.0768630504608
TOP

'TRAINING COMPLETE.'

In [21]:
#Train the third teacher model
train_teachers(n_teacher = 2, n_epochs = n_epochs, n_samples_per_epoch = n_samples_per_epoch,
               architecture = MALARIA_CNN_ARCHITECTURE, experiment = "MALARIA", loss_function = CE_Loss, 
               saddle_steps = saddle_steps, saddle_points = SADDLE_POINTS)

STEP 1
DETERMINING SADDLE POINT PRESENCE.


  tensor_labels = torch.tensor(labels).clone().detach()


NEG EIG SUM: 0.30323904752731323
POS EIG SUM: 1.4906037151813507
TOP EIGENVALUES: [1.114320158958435, -0.30323904752731323, 0.27137622237205505, 0.1049073338508606]
STEP 2
DETERMINING SADDLE POINT PRESENCE.


  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


NEG EIG SUM: 0.851729691028595
POS EIG SUM: 2.7216185331344604
TOP EIGENVALUES: [1.343822956085205, -0.851729691028595, 0.7510522603988647, 0.6267433166503906]
STEP 3
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.2565838545560837
POS EIG SUM: 1.1051031053066254
TOP EIGENVALUES: [0.9446449279785156, 0.16045817732810974, -0.15222178399562836, -0.10436207056045532]
STEP 4
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.9123912155628204
POS EIG SUM: 2.074823021888733
TOP EIGENVALUES: [1.6696124076843262, -0.41140374541282654, -0.5009874701499939, 0.40521061420440674]
APPENDED SADDLE POINTS.
STEP 5
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.5672492831945419
POS EIG SUM: 1.74656081199646
TOP EIGENVALUES: [1.3875224590301514, -0.3590136170387268, 0.3590383529663086, -0.20823566615581512]
STEP 6
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0.4416693449020386
POS EIG SUM: 1.972461313009262
TOP EIGENVALUES: [1.0807068347930908, 0.4875892996788025, -0.4416693449020386, 0.404165178

  tensor_labels = torch.tensor(labels).clone().detach()
  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


EPOCH [1/2] (VALIDATION) - LOSS: 0.635880269938045 ACCURACY: 63.845486111111114 
STEP 324
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 3.8386948108673096
POS EIG SUM: 470.175799369812
TOP EIGENVALUES: [399.2275695800781, 60.61112976074219, 10.3371000289917, -3.8386948108673096]
STEP 325
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0
POS EIG SUM: 336.1226110458374
TOP EIGENVALUES: [278.0054626464844, 21.394611358642578, 24.890689849853516, 11.831847190856934]
STEP 326
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 1.72158682346344
POS EIG SUM: 526.7120342254639
TOP EIGENVALUES: [454.40692138671875, 56.682456970214844, 15.622655868530273, -1.72158682346344]
STEP 327
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 0
POS EIG SUM: 534.6972651481628
TOP EIGENVALUES: [481.9617614746094, 36.423301696777344, 11.964614868164062, 4.3475871086120605]
STEP 328
DETERMINING SADDLE POINT PRESENCE.
NEG EIG SUM: 4.061371803283691
POS EIG SUM: 341.39030027389526
TOP EIGENVALUES: [313.7548828125, 2

'TRAINING COMPLETE.'

In [None]:
#IF OUTPUT PRESERVATION NEEDED
#Train the third teacher model
#Define loss function (cross entropy)
CE_Loss = nn.CrossEntropyLoss()  
#Define number of epochs and number of samples per epoch
n_epochs = 1
n_samples_per_epoch = 700
#Initialize saddle point dictionary
SADDLE_POINTS = {}
train_teachers(n_teacher = 0, n_epochs = n_epochs, n_samples_per_epoch = n_samples_per_epoch,
               architecture = MALARIA_CNN_ARCHITECTURE, experiment = "MALARIA", loss_function = CE_Loss, saddle_points = SADDLE_POINTS)

In [22]:
#Save Saddle Points
torch.save(SADDLE_POINTS, "TRUE_MALARIA_ENSEMBLE_TEACHER_SADDLE_POINTS.pth")

### **SPRKD STUDENT TRAINING SETUP.**

In [7]:
#Train student function
def train_student(n_epochs: int, architecture, n_samples_per_epoch: int, experiment: str, loss_function, is_control: bool, teacher_saddle_points: dict, cooldown_steps: int,
                  decay: int, selfKD: bool, stride: int, epsilon: int, PGD_delta: int, max_hessian_neg_eigensteps: int, PGD_epoch_limit : int):
  #Check experiment type and set dataloader
  dataloader = globals()[experiment + "_TRAIN_DATALOADER"]
  #Create learner object
  student_learner = globals()["STUDENT_" + experiment] = Learner(dataloader, deepcopy(architecture), metrics = "accuracy", loss_func = loss_function)
  student_learner.model = student_learner.model.to('cuda')
  #Find optimal learning rate
  lr = 0.001
  #lr = globals()["STUDENT_" + str(experiment) + "_optimalLR"] = student_learner.lr_find()[0]
  #Set loss function
  student_learner.loss_func = loss_function
  #Set Adam optimizer to take the actual steps
  ADAM_opt = Adam(student_learner.parameters(), lr = lr)
  #Set SPRKD for saddle point tracking and reversions
  optimizer = globals()["STUDENT_" + experiment].opt = SPRKD(params = student_learner.parameters(), loss_function = loss_function, 
                                                                                        stepsize = lr, bias = 0.001, generosity = 5, saddle_steps = None, 
                                                                                        is_teacher = False, is_control = is_control, teacher_saddle_points = teacher_saddle_points, optimizer = ADAM_opt,
                                                                                        epsilon = epsilon, cooldown_steps = cooldown_steps, decay = decay, selfKD = selfKD, stride = 1, 
                                                                                        PGD_delta = PGD_delta, max_hessian_neg_eigensteps = max_hessian_neg_eigensteps, PGD_epoch_limit =PGD_epoch_limit)
  #Declare loss and validation arrays
  train_losses = globals()["STUDENT_" + experiment + "_TRAIN_LOSSES"] = []
  valid_losses = globals()["STUDENT_" + experiment + "_VALID_LOSSES"] = []
  train_accs = globals()["STUDENT_" + experiment + "_TRAIN_ACCURACIES"] = []
  valid_accs = globals()["STUDENT_" + experiment + "_VALID_ACCURACIES"] = []
  #Begin training
  for epoch in range(n_epochs):
        #Training loop
        perform_model_train_loop(learner = student_learner, n_samples_per_epoch = n_samples_per_epoch, losses = train_losses, accuracies = train_accs, epoch = epoch, epochs = n_epochs)
        #Validation loop
        perform_model_valid_loop(learner = student_learner, n_samples_per_epoch = None, losses = valid_losses, accuracies = valid_accs, epoch = epoch, epochs = n_epochs)

In [8]:
#Load all teacher saddle points
SADDLE_POINTS = torch.load("MALARIA_ENSEMBLE_TEACHER_SADDLE_POINTS.pth")
print(SADDLE_POINTS.keys())

dict_keys([0, 1, 2])


In [9]:
#Initialize GPU copy of saddle points
GPU_SADDLE_POINTS = SADDLE_POINTS.copy()
for teacher, teacher_saddle_points in SADDLE_POINTS.items():
  for saddle_point_idx, saddle_point in enumerate(teacher_saddle_points): 
    for param_index, param in enumerate(saddle_point): GPU_SADDLE_POINTS[teacher][saddle_point_idx][param_index] = param.to('cuda')

#### **WEIGHT INJECTION PACKAGE (TRANSFER LEARNING BY INJECTION, TLI).**

In [10]:
#OFFICIAL CODE IMPLEMENTATION OF WEIGHT INJECTION FROM https://arxiv.org/pdf/2006.12986.pdf
#SOURCE: https://github.com/maciejczyzewski/tli-pytorch/blob/fc2ba0b1c5c6fec43fe88cda633283d464f0da00/tli.py
#PASTED AS PER USAGE INSTRUCTIONS - THIS IS AN EXTERNAL PACKAGE AND NOT PART OF THE SPRKD DEVELOPMENT PROCESS.

!pip install karateclub
# MUSIC: https://www.youtube.com/watch?v=F3OFIuIXcSo
#        https://www.youtube.com/watch?v=m_ysN9BQm8s
# FIXME: https://arxiv.org/pdf/2006.12986.pdf
# -----> depth/width/kernel level --> (podzial kodu)
# https://github.com/JaminFong/FNA/blob/master/fna_det/tools/apis/param_remap.py
# (paper) Karate Club
# https://arxiv.org/pdf/2003.04819.pdf

"""
weryfikacja:
- [ ] zrobic wizualizacje (reset -> applyied --> GT) dla KD
        --> SCORE/LOSS roznicy rozkladu --> SUMA/MEAN
        --> histogramy jako wrzuta do "folderu" dla warstwy
- [ ] wizualizacja dopasowania [1. matching 2. injection]
- [ ] zrobic exp__tli --> to samo co MNIST (1k)
          tylko modele 2flops moze 2 rozne od siebie
          jeden przeuczony drugi nie --> transfer --> patrzymy jaki score (ACC)
              [train/test mean]
technikalia:
- [ ] do kazdej warstwy dac "prawdopodobienstwo przypisania trans."
        te co maja najwyzsze prawd. (kilka) vs. (one) big boss
        to sa: rescale(X) + (wiekszawaga)*centercrop(X) + iter. mixing(zbioru)
- [ ] drzewiasty algo? similarity hashing? [[LHS]]
        jako dopasowanie!!!!!!!!!!!!!!!!!
- [ ] poczatkowe warstwy maja "wieksza wage"/"wieksze warstwy"
        --> zrobic jakas uczciwa krzywa z palca 100 -> 75 na ostatnich warstwach
- [ ] mixowanie wiele sieci z `results-imagenet`
            -> az ***nasycimy*** wszystkie wagi
- [ ] uzywanie `trace_graph` --> a nie "modulow" (uwzglednienie relacji)
- [ ] !!! UZYC graph cluster-ingu // zamiast DP
dodatki:
- [ ] FIXME: a co z reszta? np. ._bn0.num_batches_tracked
            ----------> model.state_dict().keys()
            zrobic cos w stylu --> with_meta = True
- [ ] FIXME: zrobic "szybkie" szukanie najlepszych modeli z ImageNet
        jesli ktos zdefiniuje [[auto=True]]
- [ ] analiza: https://github.com/KamilPiechowiak/weights-transfer/pull/17/files
- [ ] sprawdzic czy dziala [WS]/aug/grid tutaj?
- [ ] analiza: https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch/blob/master/README.md
- [ ] jakas porzadna nazwa np.
        yak shaving (use urban dictionary) // sponge function
                ---> still from crypto name / Unsponge ducktape
                ducktransfer
- [ ] analiza: https://github.com/MSeal/agglom_cluster
- [ ] wielopoziomowe dopasowanie/clustry (a nie standaryzacja):
        (zagniezdzone clustry)
    - in/mid/out -> block -> branch -> grupa tensorow -> tensor -> itp.
"""

# commit: dark tensor rises

import collections
import os
import random
import sys
# FIXME: repair config "reinit" case
from copy import copy
from typing import Dict, List

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from graphviz import Digraph
from karateclub import FeatherNode, NetMF
from networkx.drawing.nx_agraph import graphviz_layout
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

################################################################################
# API
################################################################################


def apply_tli(model, teacher=None):
    # print(f"[TLI]   model={model}")
    # print(f"[TLI] teacher={teacher}")
    model_teacher = str_to_model(teacher)
    transfer(model_teacher, model)
    return model


def get_tli_score(model_from, model_to):
    model_a = str_to_model(model_from)
    model_b = str_to_model(model_to)
    sim, _, _, _ = transfer(model_a, model_b)
    return sim


def get_model_timm(name="dla46x_c"):
    try:
        import timm
    except:
        raise Exception("timm package is not installed! try `pip install timm`")

    # FIXME: `channels`!!! and `classes`!!! as param (debug)
    model = timm.create_model(name, num_classes=10, in_chans=3, pretrained=True)
    return model


# FIXME: move to class ModelObj
def str_to_model(name):
    if isinstance(name, str):
        print(f"loading `{name}` from pytorch-image-models...")
        model = get_model_timm(name)
    else:  # FIXME: check if "pytorch" model
        model = name
    return model


# def get_tli_score(model_from, model_to):
#     model_a = str_to_model(model_from)
#     model_b = str_to_model(model_to)
#     score_ab = transfer(model_a, model_b)
#     score_ba = transfer(model_b, model_a)
#     sim = (score_ab + score_ba) / 2
#     print(
#         f"[score_ab={round(score_ab, 2):6} score_ba={round(score_ba, 2):6} | sim={round(sim, 2):6}]"
#     )
#     return sim


################################################################################
# Utils
################################################################################


def apply_hard_reset(model):
    for layer in model.modules():
        if hasattr(layer, "reset_parameters"):
            nn.init.zeros_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
    return model


def fn_inject(from_tensor, to_tensor):
    # FIXME: debug -> vis. -> rescale
    from_slices, to_slices = [], []
    for a, b in zip(from_tensor.shape, to_tensor.shape):
        if a < b:
            from_slices.append(slice(0, a))
            to_slices.append(slice((b - a) // 2, -((b - a + 1) // 2)))
        elif a > b:
            from_slices.append(slice((a - b) // 2, -((a - b + 1) // 2)))
            to_slices.append(slice(0, b))
        else:
            from_slices.append(slice(0, a))
            to_slices.append(slice(0, b))
    to_tensor[tuple(to_slices)] = from_tensor[tuple(from_slices)]


################################################################################
################################################################################
################################################################################

# FIXME: ladnie podzielic na "matching" / "injection"

##########################################################
# --> fn_stats() -> [abs.mean(), rozklad()]
# --> fn_kullbeck(stats1, stats2) -> [0, 1]
# FIXME: pretty list of modules? --> fn_stats
#                                    if GT -> fn_kullbeck
##########################################################

# DIST: https://github.com/timtadh/zhang-shasha
# "graph matching" https://arxiv.org/pdf/1904.12787.pdf
# https://github.com/deepmind/deepmind-research/tree/master/graph_matching_networks
# Graph / Node [Embedding???] / Graph2Vec
# https://github.com/Jacobe2169/GMatch4py
###############
# https://github.com/benedekrozemberczki/awesome-graph-classification
# WeisfeilerLehman ??????

################# BEST ###################
## https://karateclub.readthedocs.io/en/latest/notes/introduction.html
## SELF-LEARN? ---> weak-estimators???

# https://github.com/topics/graph2vec

# [FIXME] konstytucja rewolucjonisty
# kazdy blok ma: [[ a) pozycje b) strukture c) rozmiar ]]
# taki musi byc tez *score*

# FIXME: matching blokow to prekalkulacja scorow
# --> rozwazania beda tensor vs tensor

# STEPS:
# (1) (s-match) cluster vs cluster (top k/percentile)
# (2) (d-match) cluster (iterate -> tensor // double in-tree/out-tree)
# (3) (w-inject)/(k-inject) combo-inject

# pozycja -> level
# struktura -> treedist(edges)

# meta-learning?

#       a) pozycje (normalized)
#       b) strukture (graph features)
#       c) rozmiar (tensor)

# AS DIFF (a -> b)
# [a] (position, structure) -
# [b] (position, structure)

# SELF-LEARN / [[[SELF-ENCODER?]]]

# AS UNSUPERVISED? --> data augmentation / [permutation]??????
#                  --> rozne tensory z tego modulu sa?????????
# [a] [(position, structure), (tensor)] --> 0/1

# label: y = [1 -> to ten tensor], [0 -> to nie ten tensor]

# GENIUS =========
# self-learn? permutacje? graph2vec na samym sobie?
# uczy sie rozpoznawac jaki to tensor xd
# https://arxiv.org/pdf/2010.12878.pdf
# ================
# ??????????????
# http://ryanrossi.com/pubs/KDD18-graph-attention-model.pdf

# FIXME: "faster and naive alternative"
# after scoring [size] -->> [maximum weight bipartite b-matching]

########################## DRAFT ###############################################

# FIXME: [ALWAYS FIT ALL?] model.fit()
# FIXME: [ALWAYS in-tree/out-tree split]
# FIXME: graph embeddings? from all graphs? (EVEN student+teacher)
#                or per `cluster` --> graph embedding? (1)
#               and per `graph` --> graph embedding? (2)
#                      so i have then 2 different embeddings space...

# N basic    --> (nodeE, shape)
# N advanced --> (nodeE, graphE in-tree, graphE out-tree, shape)
# FIXME: there is a problem with "NODE Embedding"?
#                --> change to split in/out graph embedding?

# [S + G + N], [N_i] ---> 0/1 [smooth labeling 0.5/0.25 itc.]
# 1) structure (graph embedding)
# 2) graph (graph embedding)
# 3) node (node embedding)

# ---> laczy na chwile graf [student-teacher] --> cos na tym robi?
# ---> duzo prostsze featury // --> l / max(l) --> c / max(c)

# [XXX] READ THIS: https://markheimann.github.io/projects.html
# https://sci-hub.se/https://link.springer.com/chapter/10.1007/978-3-319-93040-4_57

# print(clf.predict(predictionData),'\n')

# model = LinearRegression().fit(X_train, y_train)
# print(model)

# y_hat = downstream_model.predict_proba(X_test)[:, 1]
# auc = roc_auc_score(y_test, y_hat)
# print('AUC: {:.4f}'.format(auc))

############################################################################

# FIXME: BIAS / WEIGHT (wildcard)
# FIXME: split_d for `student` then ensemble for encoder?

# >>> FOR FLOW
# split_map = split_flow_level(graph_teacher)
# pprint(split_map)
# encoded_split_map = encoder_graph(split_map)
# pprint(encoded_split_map)

# >>> FOR CLUSTERS
# for cluster_idx in graph_teacher.cluster_map.keys():
#     split_map = split_cluster_level(graph_teacher, cluster_idx)
#     pprint(split_map)
#     print(f"cluster_idx={cluster_idx}")
#     break

# >>> ALL FOR NODES
# edges = []
# for a, dst in graph_teacher.edges.items():
#     for b in dst:
#         edges.append([a, b])
# obj = encoder_nodes(edges)
# pprint(obj)
# sys.exit(1)

# >>> FOR NODES IN CLUSTER
# for cluster_idx in graph_teacher.cluster_map.keys():
#     obj = encoder_nodes(graph_teacher.cluster_map[cluster_idx].edges)
#     pprint(obj)
#     print("="*30)

# FIXME: KD-tree?
# FIXME: zrobic wizualizacje matchingu!!!!!!!!!!!!!!!!!!!!!!!
#     (przetestowac laczac ze soba 2 tensory)
#     (dodatek - wizualizacja dodatkowych `edges` do debugu)

#### [[[[[Fast Network Alignment]]]]]]] / xNetMF

# XXX XXX XXX XXX XXX [READ THIS] #######################
# https://gemslab.github.io/papers/heimann-2018-regal.pdf
# https://github.com/GemsLab/REGAL
#########################################################

# class NodeFeatures
#   [a] structures_info
#   [b] graph_info
#   [c] ???? shape
# for multiple matches [[ SparseMAP ]]
# ---> https://arxiv.org/pdf/1802.04223.pdf

# KD-tree? for representations?
# ----> MATRIX???

# matching if provided map


################################################################################
################################################################################
################################################################################

# XXX XXX XXX XXX XXX [READ THIS] #######################
# https://gemslab.github.io/papers/heimann-2018-regal.pdf
# https://github.com/GemsLab/REGAL
#########################################################


def get_networkx(edges, dag=True):
    if dag:
        G = nx.DiGraph()
    else:
        G = nx.Graph()
    G.add_edges_from(edges)
    return G


def show_networkx(graph):
    if isinstance(graph, list):
        graph = get_networkx(edges=graph)
    pos = graphviz_layout(graph, prog="dot")
    nx.draw(graph, pos, with_labels=True, arrows=True)
    plt.show()


def dag_split(edges, token, root=None):
    graph = {}
    for a, b in edges:
        if a not in graph:
            graph[a] = []
        if b not in graph:
            graph[b] = []
        graph[a].append(b)
        graph[b].append(a)
    edges_split = []
    visited, queue = set(), collections.deque([root])
    while queue:
        stop = False
        node_root = queue.popleft()
        if node_root not in graph:
            continue
        if node_root == token:
            break
        for node in graph[node_root]:
            if node not in visited:
                if node == token:
                    stop = True
                edges_split.append([node_root, node])
                visited.add(node)
                queue.append(node)
        if stop:
            break
    # FIXME: empty graphs?
    if not edges_split:
        edges_split.append([token, token])
    return edges_split


def graph_splits(edges, nodes=False):
    G = get_networkx(edges)
    order = list(nx.topological_sort(G))
    if len(order) == 0:
        return {}
    idx_src, idx_dst = order[0], order[-1]
    if not nodes:
        nodes = set()
        for a, b in edges:
            nodes.add(a)
            nodes.add(b)
    split_map = {}
    for idx in nodes:
        in_tree = dag_split(edges, idx, root=idx_src)
        out_tree = dag_split(edges, idx, root=idx_dst)
        split_map[idx] = {"in-tree": in_tree, "out-tree": out_tree}
    return split_map


def graph_norm(edges, attr=None):
    normal_id_map = {}
    normal_id_iter = [0]
    rev_mask = {}

    def __for_single(idx):
        if not idx in normal_id_map:
            normal_id_map[idx] = normal_id_iter[0]
            rev_mask[normal_id_iter[0]] = idx
            normal_id_iter[0] += 1

    # random.shuffle(edges)

    for a, b in edges:
        __for_single(a)
        __for_single(b)

    norm_edges = []
    for a, b in edges:
        norm_edges.append([normal_id_map[a], normal_id_map[b]])

    # norm_edges = sorted(norm_edges)

    norm_attr = []
    if attr:
        for i in range(len(normal_id_map.keys())):
            norm_attr.append(attr[rev_mask[i]])

    return norm_edges, rev_mask, norm_attr


def utils_map_to_mask(split_map):
    mask, graphs = [], []
    for key, split_dict in split_map.items():
        for dict_key in split_dict.keys():
            _g, rev_mask, _ = graph_norm(split_dict[dict_key])
            g = get_networkx(_g, dag=False)
            mask.append([key, dict_key])
            graphs.append(g)
    return mask, graphs


def utils_mask_to_map(mask, X):
    split_map = {}
    for i, (key, dict_key) in enumerate(mask):
        if key not in split_map:
            split_map[key] = {}
        split_map[key][dict_key] = X[i]
    return split_map


################################################################################


def split_flow_level(graph):
    edges = []
    for edge in graph.cluster_links:
        cluster_idx_1 = graph.nodes[edge[0]].cluster_idx
        cluster_idx_2 = graph.nodes[edge[1]].cluster_idx
        edges.append([cluster_idx_1, cluster_idx_2])
    return graph_splits(edges)


def split_cluster_level(graph, cluster_idx):
    edges = graph.cluster_map[cluster_idx].edges
    return graph_splits(edges)


def encode_graph(split_map):
    mask, graphs = utils_map_to_mask(split_map)

    # FIXME: move to settings
    from karateclub import GL2Vec

    model = GL2Vec(dimensions=16) #FeatherGraph(eval_points=2, order=2)
    print("FIT")
    model.fit(graphs)
    print("EMBEDDING")
    X = model.get_embedding()
    print("-------------------->", X.shape)

    return utils_mask_to_map(mask, X)


################################################################################
# TLI
################################################################################


class TLIConfig(object):
    def __init__(self, adict):
        self.__dict__.update(adict)

from karateclub import Diff2Vec
embedding_dim = 5  # best 4, 6, 5 / FIXME: was 9, how to find?
CONFIG = TLIConfig(
    {
        # FIXME: move outsite? --> lazy_load?
        "node_embedding_attributed": FeatherNode( # 2, 4
            eval_points=4, order=4, svd_iterations=100, reduction_dimensions=32
        ),
        "node_embedding_neighbourhood": NetMF(
             dimensions=embedding_dim
        ),  # FIXME: use xNetMF
                # Diff2Vec(diffusion_number=5, diffusion_cover=5, dimensions=embedding_dim),
        "autoencoder": MLPRegressor(
            max_iter=100, # 100 // 3,  # FIXME: best 50
            early_stopping=False,
            activation="relu",
            solver="adam",
            tol=0.0001,
            ##############################################
            # n_iter_no_change=100, # FIXME: is that good?
            ##############################################
            hidden_layer_sizes=(200, 50, 25,),  # 125, 25
            warm_start=True,
            learning_rate_init=0.0005,
            alpha=0.001,
            verbose=True,
        ),
        "test_size": 0.05,  # FIXME: this is important!
        "samples_per_tensor": 10,
    }
)


def E_nodes(edges, attr=None):
    norm_graph, rev_mask, norm_attr = graph_norm(edges, attr=attr)

    if len(rev_mask) == 0:
        return []

    model = (
        CONFIG.node_embedding_attributed
        if attr
        else CONFIG.node_embedding_neighbourhood
    )

    graph = get_networkx(norm_graph, dag=False)
    if attr:
        model.fit(graph, np.array(norm_attr))
        X = model.get_embedding()
    else:
        model.fit(graph)
        X = model.get_embedding()

    print(f"[E_nodes {X.shape}]", end="")

    encoded_nodes = {}
    for i in range(X.shape[0]):
        encoded_nodes[rev_mask[i]] = X[i]
    return encoded_nodes


def F_architecture(graph, mlb=None, mfa=None):
    ### POSITION ENCODING ###
    edges = []
    cluster_feature = {}
    for cluster_idx, cluster in graph.cluster_map.items():
        cluster_feature[cluster_idx] = [len(cluster.nodes) / (1 + len(cluster.edges))]
    for edge in graph.cluster_links:
        cluster_idx_1 = graph.nodes[edge[0]].cluster_idx
        cluster_idx_2 = graph.nodes[edge[1]].cluster_idx
        edges.append([cluster_idx_1, cluster_idx_2])
    P = E_nodes(edges, attr=cluster_feature)

    ### STRUCTURE ENCODING ###
    S = {}
    for cluster_idx in graph.cluster_map.keys():
        edges = graph.cluster_map[cluster_idx].edges
        ## obj = E_nodes(edges)
        if len(edges) > embedding_dim:
            obj = E_nodes(edges)
        else:
            obj = {}
            for idx in graph.cluster_map[cluster_idx].nodes:
                obj[idx] = np.array([0.0] * embedding_dim)  # FIXME: config
        S.update(obj)

    ### NODE ENCODING ###
    N = {}  # FIXME: move to fn_node_encoder?
    vec = []
    for idx, node in graph.nodes.items():
        vec.append(__encode(node.name))
        # vec.append(list(node.name.replace(".weight", "").replace(".bias", "")))
        # vec.append(node.name.split("."))
    vec = mlb.transform(vec)
    vec = mfa.transform(vec)
    vec_final = []
    for i, (idx, node) in enumerate(
        graph.nodes.items()
    ):  # FIXME: better way? [pad len 4]
        _shape4 = nn.ConstantPad1d((0, 4 - len(node.size)), 0.0)(
            torch.tensor(node.size)
        )
        #shape_ab = __shape_score(_shape4.type(torch.FloatTensor), (100, 1, 1, 1))
        #shape_ba = __shape_score(_shape4.type(torch.FloatTensor), (1, 100, 1, 1))
        shape4 = _shape4.type(torch.FloatTensor) / torch.max(1 + _shape4)
        if shape4[0] > shape4[1]:
            rot = 1
        else:
            rot = 0
        _idx_rev = (graph.max_idx - node.idx) / graph.max_idx
        _idx_rev2 = (node.idx) / graph.max_idx
        _level_rev = (graph.max_level - node.level) / graph.max_level
        _level_rev2 = (node.level) / graph.max_level
        _cluster_rev = (graph.max_idx - node.cluster_idx) / graph.max_idx
        _cluster_rev2 = (node.cluster_idx) / graph.max_idx
        _type = 0 if ".bias" in node.name else 1
        # dotcount = node.name.count('.')
        # N[idx] = np.array(
        vec_final.append(np.array(
            [rot]
            + shape4.tolist()
            + [(_idx_rev + _cluster_rev+_level_rev)/3,
               (_idx_rev2+_cluster_rev2+_level_rev2)/3, _type]
        ))
        # vec_final.append(np.array(
        #     # [shape_ab, shape_ba]
        #     [rot]
        #     + shape4.tolist()
        # ))
    from sklearn import preprocessing
    # _pp = preprocessing.QuantileTransformer() # BEST
    # _pp = preprocessing.QuantileTransformer() # 83 / 158
    # _pp = preprocessing.Normalizer(norm='l2') # 77 / 158
    # _pp = preprocessing.Normalizer(norm='l1') # 76 / 158
    # _pp = preprocessing.Normalizer(norm='max') # [78] 79 / 158
    # _pp = preprocessing.PowerTransformer() # 80 / 158
    # _pp = preprocessing.MaxAbsScaler() #XXX 20 77 / 158
    # _pp = preprocessing.RobustScaler() # 78 / 158
    _pp = preprocessing.StandardScaler() #XXX 85 / 158
    # _pp = preprocessing.KBinsDiscretizer(n_bins=10, encode='ordinal',
    #                                      strategy='quantile') # 75
    vec_final = _pp.fit_transform(vec_final)

    for i, (idx, node) in enumerate(
        graph.nodes.items()
    ):
        # FIXME???????? without vec_final?
        # print(vec_final[i])
        N[idx] = np.array(vec_final[i].tolist() + vec[i].tolist())

    print("(encode_graph ended)")
    return P, S, N


def __q(a, b):
    return np.array(a) + np.array(b)
    # return np.array(a) * np.array(b) # 60 / 158
    # return np.concatenate((a, b), axis=0) # 65 / 158


def __shape_score(s1, s2):
    if len(s1) != len(s2):
        return 0
    score = 1
    for x, y in zip(s1, s2):
        score *= min(x / y, y / x)
    return score


# gen_dataset / `self-learn`
def gen_dataset(graph, P, S, N, EG, prefix=""):
    X, y = [], []

    # FIXME: move to encoder settings? / encoder definition
    for idx, node in graph.nodes.items():
        if node.type != "W":  # FIXME: is it good?
            continue

        cluster_idx = node.cluster_idx
        # FIXME: make it pretty
        # FIXME: encoder score for [N]

        # === CASE 1: [self to self] (q_src, q_dst) -> 1
        for _ in range(CONFIG.samples_per_tensor):
            # FIXME: move to `augmentation`
            p_src = np.array(P[cluster_idx])
            r = np.random.uniform(low=-0.05, high=0.05, size=p_src.shape)
            p_src += r
            s_src = np.array(S[idx])
            r = np.random.uniform(low=-0.05, high=0.05, size=s_src.shape)
            s_src += r
            q_src = p_src.tolist() + s_src.tolist() + list(N[idx]) + \
                EG[f"{prefix}_{idx}"]["in-tree"].tolist()
            X.append(__q(q_src, q_src))
            # FIXME: verify 0.05, 0.05? maybe add as std/var
            y.append(1 + np.random.uniform(low=-0.05, high=0.05))

        q_src = list(P[cluster_idx]) + list(S[idx]) + list(N[idx]) + \
            EG[f"{prefix}_{idx}"]["in-tree"].tolist()

        X.append(__q(q_src, q_src))
        y.append(1)

        def __get_node(cluster_idx=None, type=None):
            r_idx = None
            if cluster_idx is not None:
                nodes = list(graph.cluster_map[cluster_idx].nodes)
            else:
                nodes = list(graph.nodes.keys())
            for _ in range(len(N)):
                r_idx = random.choice(nodes)
                if graph.nodes[r_idx].type == type or not type:
                    break
            return r_idx

        # === CASE 2: same cluster, W
        for _ in range(CONFIG.samples_per_tensor):
            r_idx = __get_node(cluster_idx=cluster_idx, type="W")
            r_cluster_idx = cluster_idx
            if idx == r_idx:
                continue

            q_dst = list(P[r_cluster_idx]) + list(S[r_idx]) + list(N[r_idx]) + \
                EG[f"{prefix}_{r_idx}"]["in-tree"].tolist()

            N_bonus = 0
            N_dist = np.linalg.norm(N[idx] - N[r_idx])

            if N_dist <= 1:
                N_bonus = (1 - N_dist) / 4

            X.append(__q(q_src, q_dst))
            y.append(
                N_bonus
                + 0.25
                + 0.5 * __shape_score(graph.nodes[idx].size, graph.nodes[r_idx].size)
            )

        # === CASE 3: other cluster, W
        for _ in range(CONFIG.samples_per_tensor):
            r_idx = __get_node(cluster_idx=None, type="W")
            r_cluster_idx = graph.nodes[r_idx].cluster_idx
            if r_cluster_idx == cluster_idx:
                continue
            if idx == r_idx:
                continue

            q_dst = list(P[r_cluster_idx]) + list(S[r_idx]) + list(N[r_idx]) + \
                EG[f"{prefix}_{r_idx}"]["in-tree"].tolist()

            N_bonus = 0
            N_dist = np.linalg.norm(N[idx] - N[r_idx])

            if N_dist <= 1:
                N_bonus = (1 - N_dist) / 4

            S_bonus = 0
            S_dist = np.linalg.norm(S[idx] - S[r_idx])

            if S_dist <= 1:
                S_bonus = (1 - S_dist) / 4

            X.append(__q(q_src, q_dst))
            y.append(
                N_bonus / 2
                + S_bonus / 2
                + 0.25 * __shape_score(graph.nodes[idx].size, graph.nodes[r_idx].size)
            )

        # === CASE 4: ?, F
        # for _ in range(CONFIG.samples_per_tensor):
        #     r_idx = __get_node(cluster_idx=None, type="F")
        #     r_cluster_idx = graph.nodes[r_idx].cluster_idx
        #     if idx == r_idx:
        #         continue

        #     q_dst = list(P[r_cluster_idx]) + list(S[r_idx]) + list(N[r_idx])

        #     X.append(__q(q_src, q_dst))
        #     y.append(0)

    print("DATASET", np.array(X).shape)#len(y))

    return X, y

# _vec = list(x.replace(".weight", "").replace(".bias", ""))
# # print(_vec)
# _lvl = [s for s in _vec if s.isdigit()]
# _lvl = "".join(_lvl)
# _vec = list(set(_vec))
# if _lvl:
#     _vec.append(_lvl)

def __encode(x):
    x = x.replace(".weight", "").replace(".bias", "")
    x = x.replace("blocks", "")
    if "Backward" in x:
        x = ""
    # print(x)
    _vec = list(x) # + [x]
    # minl, maxl = 1, 2
    # t = x
    # _vec = [t[i:i+j] for i in range(len(t)-minl) for j in range(minl,maxl+1)]
    # print(_vec)
    _lvl = [s for s in _vec if s.isdigit()]
    _lvl = "".join(_lvl)
    _vec = list(set(_vec))
    if _lvl:
        _vec.append(_lvl)
        # for i in range(2, len(_lvl)+1):
        #     _vec.append(_lvl[0:i])
    return _vec

def score_autoencoder(graph_src, graph_dst):
    # src_ids_to_layers_mapping = get_idx_to_layers_mapping(model_src,
    #                                                           graph_src)
    # dst_ids_to_layers_mapping = get_idx_to_layers_mapping(model_dst,
    #                                                           graph_dst)

    from sklearn.preprocessing import MultiLabelBinarizer
    from sklearn.manifold import Isomap

    mlb = MultiLabelBinarizer()

    vec = []
    # FIXME: mutual
    for idx, node in graph_dst.nodes.items():
        # if node.type != "W":
        #    continue
        vec.append(__encode(node.name))
    # for idx, node in graph_src.nodes.items():
    #     vec.append(__encode(node.name))
    #     vec.append(node.name.split("."))
    mlb.fit(vec) # FIXME: 50
    _l1 = len(graph_dst.nodes.keys())
    _l2 = len(graph_dst.cluster_map.keys())
    # print(_l2, _l1)
    mfa = Isomap(n_components=min(_l1//2, 30), n_neighbors=min(_l1//10, 50), p=3) # 30 best
    _vec = mlb.transform(vec)
    mfa.fit(_vec)

    P_src, S_src, N_src = F_architecture(graph_src, mlb=mlb, mfa=mfa)
    P_dst, S_dst, N_dst = F_architecture(graph_dst, mlb=mlb, mfa=mfa)

    from pprint import pprint
    split_map = {}
    for cluster_idx in graph_src.cluster_map.keys():
        _split_map = split_cluster_level(graph_src, cluster_idx)
        for key in _split_map:
            split_map[f"src_{key}"] = _split_map[key]
    print("(graph_src ended)")
    for cluster_idx in graph_dst.cluster_map.keys():
        _split_map = split_cluster_level(graph_dst, cluster_idx)
        for key in _split_map:
            split_map[f"dst_{key}"] = _split_map[key]
    print("(graph_dst ended)")
    EG = encode_graph(split_map)
    # for key in EG:
    #     print("-------->", key, EG[key]["in-tree"].shape)

    X1, y1 = gen_dataset(graph_src, P_src, S_src, N_src, EG, prefix="src")
    X2, y2 = gen_dataset(graph_dst, P_dst, S_dst, N_dst, EG, prefix="dst")
    X = X1 + X2
    y = y1 + y2

    print("DATASET FULL", np.array(X).shape)
    # for x in range(len(X)):
    #     print(np.array(X[x]).shape)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=CONFIG.test_size, random_state=42
    )

    ### AUTOENCODER ###

    # https://scikit-learn.org/stable/modules/generated/sklearn.semi_supervised.SelfTrainingClassifier.html#sklearn.semi_supervised.SelfTrainingClassifier

    model = copy(CONFIG.autoencoder)
    ### model.fit(X1, y1)
    ### model.fit(X2, y2)
    model.fit(X_train, y_train)

    #########################

    y_hat = model.predict(X_test)
    loss = mean_squared_error(y_test, y_hat)
    print(f" LOSS --> {loss}")

    #################################################################
    ## FIXME: bipartie_matching between top-k #######################
    ## FIXME: match by clusters --> if best in cluster / eliminate ##
    ## FIXME: try connection? # FIXME: elimination? greedy {top 3} ##
    #################################################################

    ### MATCHING ###

    # FIXME: move to [fn_matcher, fn_scorer]

    def __norm_weights(graph):
        arr, imap, i = [], {}, 0
        for _, (idx, node) in enumerate(graph.nodes.items()):
            if node.type != "W":
                continue
            arr.append(idx)
            imap[idx] = i
            i += 1
        return arr, imap

    src_arr, src_map = __norm_weights(graph_src)
    dst_arr, dst_map = __norm_weights(graph_dst)

    n, m = len(src_arr), len(dst_arr)
    scores = np.zeros((n, m))

    # classes = [
    #         nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    #         nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
    #         nn.Linear
    # ]

    for dst_j, idx_dst in enumerate(dst_arr):
        node_dst = graph_dst.nodes[idx_dst]
        dst_type = node_dst.name.split(".")[-1]

        q_dst = (
            list(P_dst[node_dst.cluster_idx])
            + list(S_dst[idx_dst])
            + list(N_dst[idx_dst])
            + list(EG[f"dst_{idx_dst}"]["in-tree"].tolist())
        )

        q_arr = []
        for src_i, idx_src in enumerate(src_arr):
            node_src = graph_src.nodes[idx_src]
            src_type = node_src.name.split(".")[-1]

            q_src = (
                list(P_src[node_src.cluster_idx])
                + list(S_src[idx_src])
                + list(N_src[idx_src])
                + list(EG[f"src_{idx_src}"]["in-tree"].tolist())
            )
            q_arr.append(__q(q_src, q_dst))
            scores[src_i, dst_j] = __shape_score(node_dst.size, node_src.size)

            # src_layer = src_ids_to_layers_mapping[idx_src]
            # dst_layer = dst_ids_to_layers_mapping[idx_dst]

            # not_same_class = True
            # for classname in classes:
            #     if isinstance(src_layer, classname) and \
            #         isinstance(dst_layer, classname):
            #             not_same_class = False
            #             break

            if dst_type != src_type:  # or not_same_class:
                scores[src_i, dst_j] = 0

        y_hat = model.predict(q_arr)
        scores[:, dst_j] *= y_hat

    return scores, src_arr, dst_arr


def transfer(model_src, model_dst=None, teacher=None, inject=True, debug=False):
    # FIXME: replace str to model if needed
    if model_src and model_dst:
        # API: v2
        print("API: V2")
        pass
    elif not model_dst and teacher:
        # API: v1
        print("API: V1")
        model_src, model_dst = teacher, model_src
    else:
        raise Exception("where is teacher?! is this a joke?")

    graph_src = get_graph(model_src)
    graph_dst = get_graph(model_dst)

    if debug:
        show_graph(graph_src, ver=3, path="__tli_src")
        show_graph(graph_dst, ver=3, path="__tli_dst")

    scores, src_arr, dst_arr = score_autoencoder(graph_src, graph_dst)

    remap = {}
    n, m = len(src_arr), len(dst_arr)

    ##############################################

    # for size in np.arange(0.10, 0.50, 0.10):
    #     window_size = size
    #     for _dst_j, idx_dst in enumerate(dst_arr[::-1]):
    #         dst_j = m - _dst_j - 1
    #         ith = dst_j / m
    #         shift = max(int(ith*n - window_size*n), 0)
    #         i = np.argmax(scores[shift:shift+int(window_size*n), dst_j])+shift
    #         if idx_dst not in remap and scores[i, dst_j] > 1 - size:
    #             remap[idx_dst] = src_arr[i]

    beta = 0.5
    smap = copy(scores)
    for _ in range(n*m):
        i, j = np.unravel_index(smap.argmax(), smap.shape)
        smap[i, :] *= beta
        # smap[:, j] *= 0.9 # FIXME
        if dst_arr[j] not in remap:
            smap[:, j] = 0
            remap[dst_arr[j]] = src_arr[i]

    window_size = 0.25
    for _dst_j, idx_dst in enumerate(dst_arr[::-1]):
        dst_j = m - _dst_j - 1
        ith = dst_j / m
        shift = max(int(ith*n - window_size*n), 0)
        i = np.argmax(smap[shift:, dst_j])+shift
        if idx_dst not in remap:
            remap[idx_dst] = src_arr[i]

    window_size = 1
    for _dst_j, idx_dst in enumerate(dst_arr[::-1]):
        dst_j = m - _dst_j - 1
        ith = dst_j / m
        shift = max(int(ith*n - window_size*n), 0)
        i = np.argmax(smap[shift:, dst_j])+shift
        if idx_dst not in remap:
            remap[idx_dst] = src_arr[i]

    ##############################################

    seen = set()
    all_scores = []
    error_n, error_sum = 0, 0
    print(" "*45 + "[[src]]" + " "*30 + "[[dst]]")
    for j, idx_dst in enumerate(dst_arr):
        node_dst = graph_dst.nodes[idx_dst]

        idx_src = remap[idx_dst]
        score = scores[src_arr.index(idx_src), j]  # src_i, dst_i
        all_scores.append(score)

        name_src = graph_src.nodes[idx_src].name
        name_dst = node_dst.name
        color_code = "\x1b[1;37;40m"
        if name_src != name_dst:
            error_sum += 1
            color_code = "\x1b[1;31;40m"

        color_end = "\x1b[0m"
        print(
            f"src= {idx_src:3} | dst= {idx_dst:3} | "
            + f"S= {round(score, 2):4} | {color_code}{name_src:30}{color_end} / "
            + f"{name_dst:10}"
        )

        seen.add(idx_src)
        error_n += 1

    sim = max(0, min(1, np.mean(all_scores)))

    print("=== MATCH =================")
    n = len(graph_src.nodes.keys())
    print(f"  SIM --> \x1b[0;34;40m{round(sim, 4)}\x1b[0m")
    print(f" SEEN --> {len(seen):5} / {n:5} | {round(len(seen)/n,2)}")
    print(f"ERROR --> {error_sum:5} / {error_n:5} | {round(error_sum/error_n,2)}")
    print("===========================")

    #############################################

    # FIXME: dwa razy odpalone?
    # FIXME: choose bigger model to smaller? --> argmax [matrix]
    # FIXME: wes argmax dla wiekszego modelu?
    # FIXME: [(maximum cover, max flow, biparte)]

    if debug:
        # FIXME: do pracy dodac rysunek z sieci typu "debug"
        show_remap(graph_src, graph_dst, remap, path="__tli_remap")

    if inject:
        p_src_ref = {}
        for name, param in model_src.named_parameters():
            p_src_ref[name] = param
        p_dst_ref = {}
        for name, param in model_dst.named_parameters():
            p_dst_ref[name] = param

        with torch.no_grad():
            for idx_dst, idx_src in remap.items():
                node_src = graph_src.nodes[idx_src]
                node_dst = graph_dst.nodes[idx_dst]
                p_src = p_src_ref[node_src.name]
                p_dst = p_dst_ref[node_dst.name]
                fn_inject(p_src, p_dst)

    return sim, remap, graph_src, graph_dst


################################################################################
# Trace Graph
################################################################################


class Node:
    def __init__(self):
        self.idx = 0
        self.var = None
        self.type = None
        self.size = ()
        self.level = 1
        self.cluster_idx = 1


class Graph:
    def __init__(self):
        self.nodes = None
        self.edges = None

        self.cluster_map = None
        self.cluster_links = None

        self.max_level = None
        self.max_idx = None


class Cluster:
    def __init__(self):
        self.cluster_idx = 0
        self.nodes = []
        self.edges = []


def make_graph(var, params=None) -> Graph:
    graph = Graph()  # FIXME: move to CONFIG
    mod_op = ["AddBackward0", "MulBackward0", "CatBackward"]

    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}

    def __get_type(var):
        node = Node()
        node.var = var
        if hasattr(var, "variable"):
            u = var.variable
            node_name = param_map[id(u)]
            size = list(u.size())
            node.name = node_name
            node.size = size
            node.type = "W"
        else:
            node_name = str(type(var).__name__)
            if node_name in mod_op:
                node.type = "OP"
            else:
                node.type = "F"
            node.name = node_name
        return node

    normal_id_map = {}
    normal_id_iter = [0]

    def __normal_id(var):
        __pointer_idx = id(var)
        if __pointer_idx in normal_id_map:
            return normal_id_map[__pointer_idx]
        else:
            normal_id_map[__pointer_idx] = normal_id_iter[0]
            normal_id_iter[0] += 1
            return normal_id_iter[0] - 1

    def __bfs(graph, degree=2):
        nodes = {}
        edges = {}
        _rev_edges = {}
        _level_map = {}
        _mod_op_map = {}
        visited, queue = set(), collections.deque([graph])
        while queue:
            var = queue.popleft()
            idx_root = __normal_id(var)
            if idx_root not in _level_map:
                _level_map[idx_root] = 1
            if idx_root not in _mod_op_map:
                _mod_op_map[idx_root] = idx_root
            if idx_root not in nodes:  # FIXME: for root? yes?
                nodes[idx_root] = __get_type(var)
                nodes[idx_root].cluster_idx = idx_root
                nodes[idx_root].type = "OP"
            if idx_root not in edges:
                edges[idx_root] = []
            if idx_root not in _rev_edges:
                _rev_edges[idx_root] = []
            for _u in var.next_functions:
                u = _u[0]
                idx = __normal_id(u)
                if not u:
                    continue
                edges[idx_root].append(idx)
                if idx not in _rev_edges:
                    _rev_edges[idx] = []
                _rev_edges[idx].append(idx_root)
                if u not in visited:
                    _level_map[idx] = _level_map[idx_root] + 1
                    node = __get_type(u)
                    node.idx = idx
                    if node.type == "OP":
                        _mod_op_map[idx] = idx_root
                    else:
                        _mod_op_map[idx] = _mod_op_map[idx_root]
                    node.level = _level_map[idx]
                    node.cluster_idx = _mod_op_map[idx]
                    nodes[idx] = node
                    # print(f"--> {node.name:30} | {_level_map[idx]:10} " + \
                    #      f">> {_mod_op_map[idx]:10}")
                    visited.add(u)
                    queue.append(u)
        ### === split by degree
        ## FIXME: add min. [branch depth?]
        ## FIXME: next tour (remove "dummy nodes" / [is_op->is_op])
        if degree:
            visited, queue = set(), collections.deque([graph])
            for idx_root in _rev_edges:
                # print(f"----> root {nodes[idx_root].name:50} {len(_rev_edges[idx_root])}")
                if len(_rev_edges[idx_root]) >= degree \
                        and nodes[idx_root].type != "W": # FIXME: bug?
                    # print("\t[MATCH]")
                    nodes[idx_root].type = "OP"
            while queue:
                var = queue.popleft()
                idx_root = __normal_id(var)
                for _u in var.next_functions:
                    u = _u[0]
                    idx = __normal_id(u)
                    if not u:
                        continue
                    if u not in visited:
                        node = nodes[idx]
                        if node.type == "OP":
                            _mod_op_map[idx] = idx_root
                        else:
                            _mod_op_map[idx] = _mod_op_map[idx_root]
                        node.cluster_idx = _mod_op_map[idx]
                        nodes[idx] = node
                        visited.add(u)
                        queue.append(u)
        max_level = 0
        for _, node_level in _level_map.items():
            max_level = max(max_level, node_level)
        return nodes, edges, max_level

    if isinstance(var, tuple):
        raise Exception("Lord Dark Tensor: have not implemented that feature")
        sys.exit(1)
        for v in var:
            __bfs(v.grad_fn)
    else:
        # FIXME: option to choose method? (degree=None)
        # FIXME: add to config
        nodes, edges, max_level = __bfs(var.grad_fn)  # , degree=None)

    graph.nodes = nodes
    graph.edges = edges

    # make clusters
    graph.cluster_map, graph.cluster_links = make_clusters(graph)
    if len(graph.cluster_map.keys()) <= 1:
        graph.cluster_links.append([0, 0])

    # graph meta
    graph.max_level = max_level
    graph.max_idx = normal_id_iter[0]

    return graph


def make_clusters(graph):
    cluster_map = {}
    cluster_links = []
    for idx, node in graph.nodes.items():
        if node.cluster_idx not in cluster_map:
            # print(f"creating cluster {node.cluster_idx}")
            cluster_map[node.cluster_idx] = Cluster()
        cluster_map[node.cluster_idx].nodes.append(idx)
    for idx_root, edges in graph.edges.items():
        node_root = graph.nodes[idx_root]
        for idx in edges:
            if graph.nodes[idx].type == "OP":
                cluster_links.append([idx, idx_root])
                continue
            cluster_map[node_root.cluster_idx].edges.append([idx, idx_root])
    return cluster_map, cluster_links


def get_graph(model, input=None):
    # FIXME: (automatic) find `input` size (just arr?) / (32, 1, 31, 31)
    graph = None
    input_shape = [input] if input else [(3, 32, 32), (1, 31, 31), (3, 224, 224)]
    for _input_shape in input_shape:
        x = torch.randn(32, *_input_shape)
        try:
            x = x.to(device) # FIXME: more pretty?
            model = model.to(device)
            graph = make_graph(model(x), params=dict(model.named_parameters()))
            break
        except Exception as err:
            print("ERROR", err)
            continue
    if not graph:
        raise Exception("something really wrong!")
    return graph


def get_idx_to_layers_mapping(model: nn.Module, graph: Graph) -> Dict[int, nn.Module]:
    names_to_layers_mapping = {}

    def dfs(model: nn.Module, name_prefix: List[str]):
        for child_name, child in model.named_children():
            dfs(child, name_prefix + [child_name])
        names_to_layers_mapping[".".join(name_prefix)] = model

    dfs(model, [])

    ids_to_layers_mapping = {}
    for node in graph.nodes.values():
        if node.type == "W":
            node_name = node.name.replace(".weight", "").replace(".bias", "")
            layer = names_to_layers_mapping[node_name]
            ids_to_layers_mapping[node.idx] = layer

    return ids_to_layers_mapping


################################################################################
# Visualization
################################################################################


def make_dot(graph, ver=0, prefix="", rankdir="TB"):
    graph_idx = id(graph)

    node_attr = dict(
        style="filled",
        shape="box",
        align="left",
        fontsize="12",
        ranksep="0.1",
        height="0.2",
        # rank="same"
    )

    graph_attr = dict(
        rank="same",
        # splines="true",
        rankdir=rankdir,  # rankdir,
        # ratio="compress",
        # overlay="compress",
        # quadtree="true",
        # overlap="prism",
        # overlap_scaling="0.01"
    )

    print(f"graph_idx={graph_idx}")
    graph_name = f"cluster_{graph_idx}"  # if rankdir == "TB" else str(graph_idx)
    dot = Digraph(name=graph_name, node_attr=node_attr, graph_attr=graph_attr)

    cluster_map, cluster_links = graph.cluster_map, graph.cluster_links

    def __show_graph_nodes():
        for idx, node in graph.nodes.items():
            _header_name = (
                f"[c = {node.cluster_idx} / "
                + f"l = {node.level} / "
                + f"idx = {node.idx}]\n{node.name}"
            )
            if node.type == "OP":
                dot.node(prefix + str(idx), _header_name, fillcolor="green")
            elif node.type == "W":
                dot.node(
                    prefix + str(idx),
                    _header_name + f"\n{node.size}",
                    fillcolor="lightblue",
                )
            else:
                dot.node(prefix + str(idx), _header_name)

    def __show_graph_edges():
        for idx_root, edges in graph.edges.items():
            for idx in edges:
                dot.edge(prefix + str(idx), prefix + str(idx_root), color="black")

    def __show_clusters():
        for cluster_idx, cluster in cluster_map.items():
            with dot.subgraph(name=f"cluster_{graph_idx}_{cluster_idx}") as c:
                c.attr(style="filled", color="lightgrey")
                for edge in cluster.edges:
                    c.edge(prefix + str(edge[0]), prefix + str(edge[1]), color="black")
                c.attr(label=f"cluster {cluster_idx}")
                if rankdir == "LR":
                    c.attr(rotate="90", rankdir="LR")

    if ver == 0:  # orginalny przelyw
        __show_graph_nodes()
        __show_graph_edges()

    if ver == 1:  # przeplyw pomiedzy clustrami
        cluster_seen = set()

        for idx, node in graph.nodes.items():
            if node.type == "OP" and node.cluster_idx not in cluster_seen:
                nodes_in_cluster = len(cluster_map[node.cluster_idx].nodes)
                name = f"{node.cluster_idx} ({nodes_in_cluster})"
                dot.node(prefix + str(node.cluster_idx), name, fillcolor="orange")
                cluster_seen.add(node.cluster_idx)

        for edge in cluster_links:
            cluster_idx_1 = graph.nodes[edge[0]].cluster_idx
            cluster_idx_2 = graph.nodes[edge[1]].cluster_idx
            dot.edge(
                prefix + str(cluster_idx_1),
                prefix + str(cluster_idx_2),
                color="darkgreen",
                penwidth="3",
            )

    if ver == 2:  # sciezki w clustrach
        __show_clusters()

    if ver == 3:  # pelny przeplyw pomiedzy clustrami
        __show_graph_nodes()

        for edge in cluster_links:
            # FIXME: constraint="false", minlen="2"
            dot.edge(
                prefix + str(edge[0]),
                prefix + str(edge[1]),
                color="darkgreen",
                minlen="3",
                penwidth="3",
            )

        __show_clusters()

    resize_dot(dot)
    dot.engine = "dot"
    return dot


def resize_dot(dot, size_per_element=0.15, min_size=12):
    num_rows = len(dot.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    dot.graph_attr.update(size=size_str)
    return size


def show_graph(model, ver=0, path="__tli_debug", input=None):
    # FIXME: warning about 'graphviz'
    if not isinstance(model, Graph):
        graph = get_graph(model, input=input)
    else:
        graph = model
    dot = make_dot(graph, ver=ver, prefix="this")
    dot.render(filename=path)
    os.system(f"rm {path}")
    print("saved to file")


def show_remap(g1, g2, remap, path="__tli_debug", for_edges = False):
    # FIXME: colors? for each cluster?
    # FIXME: show as matrix? A: top-down B: left-right
    dot_g1 = make_dot(g1, ver=3, prefix="src", rankdir="TB")
    dot_g2 = make_dot(g2, ver=3, prefix="dst", rankdir="LR")

    graph_attr = dict(rankdir="TB",)
    dot = Digraph(name="root", graph_attr=graph_attr)
    dot_g2.graph_attr.update(rotate="90")
    ###
    ### dot.graph_attr.update(rank="same", ranksep="5", nodesep="2", pad="2")
    ###
    dot_g2.graph_attr.update(compound="True")
    dot_g1.graph_attr.update(compound="True")
    dot.graph_attr.update(compound="True")  # , peripheries="0")
    dot.subgraph(dot_g2)
    dot.subgraph(dot_g1)
    from matplotlib.colors import to_hex
    import matplotlib.pyplot as plt

    cmap = plt.get_cmap("gist_rainbow")
    if not for_edges:
        arr = g1.cluster_map.keys()
    else:
        arr = range(len(remap.keys()))

    colors = cmap(np.linspace(0, 1, len(arr)))
    colors_map = {}  # FIXME: sorted?
    for (i, color) in zip(arr, colors):
        colors_map[i] = color

    for i, (idx_dst, idx_src) in enumerate(remap.items()):
        if not for_edges:
            color = colors_map[g1.nodes[idx_src].cluster_idx]
        else:
            color = colors_map[i]
        dot.edge(
            "src" + str(idx_src),
            "dst" + str(idx_dst),
            color=to_hex(color),
            # color="red",
            constraint="false",
            penwidth="5",
            weight="5",
        )
    dot.render(filename=path)
    os.system(f"rm {path}")
    print("saved to file")


################################################################################
# Debug
################################################################################

# if __name__ == "__main__":
#     if True:
#         from research_models import get_model_debug, ResNetUNet

#         model_debug_small = get_model_debug(seed=42, channels=3, classes=10)
#         model_debug_large = get_model_debug(seed=3, channels=3, classes=10)
#         model_unet = ResNetUNet(n_class=6)

#         show_graph(model_debug_small, ver=0, path="__tli_figure_1_all")
#         show_graph(model_debug_small, ver=3, path="__tli_figure_1_graph")

#         transfer(model_debug_small, model_debug_large, debug=True)

#         show_graph(model_unet, ver=1, path="__tli_figure_unet")
#         # sys.exit()

#     if False:  # 8, 11, 9
#         model_A = get_model_timm("efficientnet_lite1")
#         model_B = get_model_timm("mnasnet_100")

#     if False:  # 0, 5, 0, 2
#         model_A = get_model_timm("efficientnet_lite1")
#         model_B = get_model_timm("efficientnet_lite0")

#     if True:  # 47, 53, 49, 45, 47, 45
#         model_A = get_model_timm("efficientnet_lite0")
#         model_B = get_model_timm("efficientnet_lite1")

#     if False:  # 9, 9, 4, 2
#         model_A = get_model_timm("efficientnet_lite1")
#         model_B = get_model_timm("efficientnet_lite1")

#     if False:  # 2, 5, 0, 0
#         model_A = get_model_timm("efficientnet_lite0")
#         model_B = get_model_timm("efficientnet_lite0")

#     if False:  # [5, 15, 4] 5
#         model_A = get_model_timm("mixnet_s")
#         model_B = get_model_timm("mixnet_s")

#     if False:  # [83, 77, 85, 78] 82
#         model_A = get_model_timm("mixnet_s")
#         model_B = get_model_timm("mixnet_m")

#     if False:  # [26, 23, 26] 22, 21
#         model_A = get_model_timm("mixnet_m")
#         model_B = get_model_timm("mixnet_s")

#     if False:  # [81, 74, 73, 71, 69] 68, 70, 64
#         model_A = get_model_timm("efficientnet_lite1")
#         model_B = get_model_timm("tf_efficientnet_b0_ap")

#     if False:  # Q: [66, 26, 24, 31, 25, 24, 29,] 30, 24, 18
#         model_A = get_model_timm("tf_efficientnet_b0_ap")
#         model_B = get_model_timm("mnasnet_100")

#     if False: # Q: [76, 61, 60, 58, 57, 57, 62] 57, 57, 55
#         model_A = get_model_timm("mixnet_s")
#         model_B = get_model_timm("mnasnet_100")

#     if False:  # not comparable
#         model_A = get_model_timm("regnetx_002")
#         model_B = get_model_timm("efficientnet_lite0")

#     # FIXME: automatic report

#     transfer(model_A, model_B, debug=True)  # tli sie

    # FIXME: normalize score [0, 1], maybe mean?
    # model_A = get_model_timm("efficientnet_lite0")
    # model_B = get_model_timm("efficientnet_lite1")
    # sim_ab = get_tli_score(model_A, model_B)
    # sim_ba = get_tli_score(model_B, model_A)
    # print(f"sim_ab = {round(sim_ab, 4)} | sim_ba = {round(sim_ba, 4)}")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting karateclub
  Downloading karateclub-1.3.3.tar.gz (64 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.5/64.5 KB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting numpy<1.23.0
  Using cached numpy-1.22.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.9 MB)
Collecting networkx<2.7
  Downloading networkx-2.6.3-py3-none-any.whl (1.9 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m31m18.1 MB/s[0m eta [36m0:00:01[0m
Collecting pygsp
  Downloading PyGSP-0.5.1-py2.py3-none-any.whl (1.8 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m64.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gensim>=4.0.0
  Downloading gensim-4.3.0-cp38-cp38-m

### **SPRKD STUDENT TRAINING.**

In [37]:
#Generate approximated teacher saddle point region
def generateTeacherApproxSaddleRegion(GPU_SADDLE_POINTS):
    #Average all best Saddle Points
    #Store final saddle point of first teacher
    total_best_solution = GPU_SADDLE_POINTS[0][-1]
    #Iterate over each teacher in the saddle point dictionary
    #Start with 1 as we already have the best saddle point for the first teacher
    for teacher_index in range(1, len(GPU_SADDLE_POINTS.keys())):
        #Store the last saddle point for the given teacher 
        best_teacher_sp = GPU_SADDLE_POINTS[teacher_index][-1]
        #Iterate across the current best solution and best parameter list, adding the best parameters to the best solution parameters
        for current_param, best_teacher_sp_param in zip(total_best_solution, best_teacher_sp):
            #Add the current parameter to the best teacher parameter (create a total, to be averaged)
            current_param = current_param.add(best_teacher_sp_param)
    #Average the total best solution 
    for param_index, accumulated_param in enumerate(total_best_solution):
        total_best_solution[param_index] = accumulated_param/len(GPU_SADDLE_POINTS.keys())
    #Load teacher_model
    globals()["TEACHER_1_MALARIA"] = TEACHER_MALARIA = Learner(MALARIA_TRAIN_DATALOADER, MALARIA_CNN_ARCHITECTURE, metrics = "accuracy", loss_func = nn.CrossEntropyLoss())
    #TEACHER_1_MALARIA.model.load("TEACHER_1_MALARIA")
    #Load student model
    globals()["STUDENT_MALARIA"] = STUDENT_MALARIA = Learner(MALARIA_TRAIN_DATALOADER, MALARIA_CNN_KD_ARCHITECTURE_WR, metrics = "accuracy", loss_func = nn.CrossEntropyLoss())
    #TEACHER_1_MALARIA.learner.load("TEACHER_1_MALARIA.pth")
    #Create copy of state dict to modify and then load
    m_state_dict = TEACHER_1_MALARIA.state_dict()
    with torch.no_grad():
        for param_index, param in enumerate(TEACHER_1_MALARIA.parameters()):
            #Get layer key at that index
            key = list(m_state_dict.keys())[param_index]
            #Set weight
            m_state_dict[key] = total_best_solution[param_index]
    #Apply Transfer Learning by Injection (TLI) - re-initialize student parameters (re-declare model)
    #TLI parameters should be in terms of the student, rather than the student being initialized to the TLI-injected parameters
    #First store original student parameters
    orig_params = []
    [orig_params.append(param.clone().detach()) for param in STUDENT_MALARIA.parameters()]
    #Apply TLI
    apply_tli(STUDENT_MALARIA, teacher = TEACHER_1_MALARIA)
    #Store TLI saddle points
    TLI_INJECTED_SADDLE_POINTS = []
    for param_idx, param in enumerate(STUDENT_MALARIA.parameters()):
        TLI_INJECTED_SADDLE_POINTS.append(param.data.clone().detach())
    #Reset all model parameters
    defineModels()
    #Redeclare student
    globals()["STUDENT_MALARIA"] = STUDENT_MALARIA = Learner(MALARIA_TRAIN_DATALOADER, MALARIA_CNN_KD_ARCHITECTURE_WR, metrics = "accuracy", loss_func = nn.CrossEntropyLoss())
    #Return TLI points
    return orig_params, TLI_INJECTED_SADDLE_POINTS

In [24]:
#Get TLI points
orig_params, TLI_saddles = generateTeacherApproxSaddleRegion(GPU_SADDLE_POINTS = GPU_SADDLE_POINTS)
#Print values to ensure parameters are correct
print("S1")
print([print(param) for param in orig_params]) 
print("TLI")
print([print(param) for param in TLI_saddles]) 
print("S2")
print([print(param) for param in STUDENT_MALARIA.parameters()]) 

API: V2
[E_nodes (1, 32)][E_nodes (22, 5)](encode_graph ended)
[E_nodes (1, 32)][E_nodes (22, 5)](encode_graph ended)
(graph_src ended)
(graph_dst ended)
FIT
EMBEDDING
--------------------> (88, 16)
DATASET (153, 72)
DATASET (154, 72)
DATASET FULL (307, 72)
Iteration 1, loss = 0.15639171
Iteration 2, loss = 0.06710181
Iteration 3, loss = 0.05536314
Iteration 4, loss = 0.04637740
Iteration 5, loss = 0.03403355
Iteration 6, loss = 0.02667419
Iteration 7, loss = 0.02409733
Iteration 8, loss = 0.02045904
Iteration 9, loss = 0.01562859
Iteration 10, loss = 0.01369159
Iteration 11, loss = 0.01329088
Iteration 12, loss = 0.01227812
Iteration 13, loss = 0.01057903
Iteration 14, loss = 0.00971030
Iteration 15, loss = 0.00969079
Iteration 16, loss = 0.00872298
Iteration 17, loss = 0.00728801
Iteration 18, loss = 0.00676438
Iteration 19, loss = 0.00668000
Iteration 20, loss = 0.00631133
Iteration 21, loss = 0.00573331
Iteration 22, loss = 0.00541926
Iteration 23, loss = 0.00519362
Iteration 24, l

In [49]:
#Train student
#Define parameters
n_epochs = 10
#Define loss function  (cross entropy)
Cross_Entropy_Loss = nn.CrossEntropyLoss()
#Train student model
train_student(n_epochs = n_epochs, architecture = MALARIA_CNN_KD_ARCHITECTURE_WR, n_samples_per_epoch = 700, experiment = "MALARIA",
              loss_function = Cross_Entropy_Loss, is_control = False, teacher_saddle_points = TLI_saddles, cooldown_steps = 20, decay = 5, 
              selfKD = False, stride = 1, epsilon = 10e-4, PGD_delta = 2.55, max_hessian_neg_eigensteps = 1000, PGD_epoch_limit = 15)

  tensor_labels = torch.tensor(labels).clone().detach()
  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


STEP 1
{}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.1663, device='cuda:0', grad_fn=<DivBackward0>)

TRANSFORMATION [0] STEP TAKEN.

TRANSFORMATION [0] MATRIX NORM: 405.8822937011719

EUCLIDEAN DIST. [0] MATRIX: tensor([[0.1159, 0.2368, 0.3580],
        [0.2023, 0.4043, 0.2560]], device='cuda:0',
       grad_fn=<DiagonalBackward0>)

EUCLIDEAN DIST. [0] MATRIX NORM: 0.6838157773017883

TRANSFORMATION [1] STEP TAKEN.

TRANSFORMATION [1] MATRIX NORM: 11.908448219299316

EUCLIDEAN DIST. [1] MATRIX: tensor([0.2523, 0.2563], device='cuda:0', grad_fn=<DiagonalBackward0>)

EUCLIDEAN DIST. [1] MATRIX NORM: 0.3596019148826599

TRANSFORMATION [2] STEP TAKEN.

TRANSFORMATION [2] MATRIX NORM: 42.14457321166992

EUCLIDEAN DIST. [2] MATRIX: tensor([[0.2376, 0.3176, 0.1796],
        [0.4507, 0.3429, 0.3509]], device='cuda:0',
       grad_fn=<DiagonalBackward0>)

EUCLIDEAN DIST. [2] MATRIX NORM: 0.7958228588104248

TRANSFORMATION [3] STEP TAKEN.

TRANSFORMATION [3] MATRIX NORM: 4.273544788360596

  tensor_labels = torch.tensor(labels).clone().detach()
  acc = 100 * torch.eq(torch.tensor(predicted_labels), torch.tensor(labels)).sum().item() / batch_size


EPOCH [1/10] (VALIDATION) - LOSS: 0.6931570056411955 ACCURACY: 49.985532407407405 
STEP 324
{0: False, 1: False, 2: False, 3: False, 4: False, 5: False, 6: False, 7: False}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.0114, device='cuda:0', grad_fn=<DivBackward0>)
STEP 325
{0: False, 1: False, 2: False, 3: False, 4: False, 5: False, 6: False, 7: False}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.0114, device='cuda:0', grad_fn=<DivBackward0>)
STEP 326
{0: False, 1: False, 2: False, 3: False, 4: False, 5: False, 6: False, 7: False}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.0115, device='cuda:0', grad_fn=<DivBackward0>)
STEP 327
{0: False, 1: False, 2: False, 3: False, 4: False, 5: False, 6: False, 7: False}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.0115, device='cuda:0', grad_fn=<DivBackward0>)
STEP 328
{0: False, 1: False, 2: False, 3: False, 4: False, 5: False, 6: False, 7: False}
AVERAGE NORM FROM APPROX. SADDLE REGION: tensor(0.0115, device='cuda:0', grad_f

KeyboardInterrupt: 

In [None]:
#Save model for Hessian Eigen Spectral Decomposition and Loss Landscape Visualization
torch.save(STUDENT_MALARIA, "MODELS/DEC2_NEGHESS_10_SPRKD_MALARIA.pth")
import pickle as pkl
#Save training and validation losses + accuracy lists for further analysis
#Create Dictionary of losses for ease in saving/loading
SPRKD_METRICS = {"TRAINING": {"LOSSES" : STUDENT_MALARIA_TRAIN_LOSSES, "ACCURACIES" : STUDENT_MALARIA_TRAIN_ACCURACIES}, 
                 "VALIDATION" : {"LOSSES" : STUDENT_MALARIA_VALID_LOSSES, "ACCURACIES" : STUDENT_MALARIA_VALID_ACCURACIES}}
with open("METRICS/LOSSES AND ACCURACIES/DEC2_NEGHESS_10_SPRKD_LOSSES.pkl", 'wb') as loss_file:
    pkl.dump(SPRKD_METRICS, loss_file)

In [None]:
#Rudimentary loss visualization (actual visualizations for accuracy, Hessian Eigen Spectral Density, and Loss Landscapes are in an alternate notebook)
import matplotlib.pyplot as plt
%matplotlib inline

print(len(STUDENT_MALARIA_VALID_LOSSES))
print(STUDENT_MALARIA_TRAIN_LOSSES)
plt.figure(figsize = (8, 6))
plt.plot(list(map(float, STUDENT_MALARIA_TRAIN_LOSSES)), label = "TRAIN")
plt.plot(list(map(float, STUDENT_MALARIA_VALID_LOSSES)), label = "VALID")
plt.xlabel("STEP")
plt.ylabel("LOSS")
plt.title("MALARIA CNN ARCHITECTURE - SPRKD DISTILLATION.")
plt.legend()
plt.show()

In [None]:
#Rudimentary accuracy plot (refer to the first comment of the previous cell)
import matplotlib.pyplot as plt
%matplotlib inline

# print(len(STUDENT_CIFAR_100_TRAIN_LOSSES))
plt.figure(figsize = (8, 6))
plt.plot(list(map(float, STUDENT_MALARIA_TRAIN_ACCURACIES[:])), label = "TRAIN")
plt.plot(list(map(float, STUDENT_MALARIA_VALID_ACCURACIES[:])), label = "VALID")
plt.xlabel("STEP")
plt.ylabel("ACCURACY")
plt.title("SIMARD ARCHITECTURE - SPRKD DISTILLATION.")
plt.legend()
plt.show()