# This code demonstrates that the layerwise computation is accurate in the same way that the easy ntk experiment is. We also verify that the easy_NTK gives the same answer as layerwise_NTK, lending credibility to both.

In [1]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

#from ..easy_ntk import calculate_NTK
#from einops import rearrange

import time

import sys
from pathlib import Path

from numba import njit

In [2]:
import copy

def _del_nested_attr(obj, names):
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj, names, value):
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def extract_weights(mod):
    """
    This function removes all the Parameters from the model and
    return them as a tuple as well as their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and after this
    call, mod.parameters() will be empty.
    """
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        _del_nested_attr(mod, name.split("."))
        names.append(name)

    # Make params regular Tensors instead of nn.Parameter
    params = tuple(p.detach().requires_grad_() for p in orig_params)
    return params, names

def load_weights(mod, names, params):
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and so are left
    as Tensors. This means that mod.parameters() will still be empty after this call.
    """
    for name, p in zip(names, params):
        _set_nested_attr(mod, name.split("."), p)
        
def calculate_NTK(model,x,device='cpu',MODE='samples'):
    """
    INPUTS:
        model: torch.nn.Module 
        x: torch.Tensor
        device: 'cpu',
        MODE: 'minima'
    
    OUTPUTS:
        NTK: torch.Tensor
    
    Calculates the NTK for a model, p_dict a state dictionary, and x, a single tensor fed into the model
    
    The NTK is the grammian of the Jacobian of the model output to w.r.t. the weights of the model
    
    This function will output the NTK such that the minima matrix size is used. If the Jacobian is an NxM
    matrix, then the NTK is formulated so that if N < M; NTK is NxN. If M<N, then NTK is MxM.
    
    #EXAMPLE USAGE:
    device='cpu'
    model = MODEL() #a torch.nn.Module object 
    model.to(device)
    state_dict = model.state_dict()

    x_test = np.ones((100,1,28,28),dtype=np.float32)
    x_test = torch.from_numpy(x_test)

    NTK = calculate_NTK(model,x_test)
    """
    if not(MODE in ['minima','samples','params']):
        raise ValueError("MODE must be one of 'minima','samples','params'")
    
    x = x.to(device)
    x.requires_grad=False
    N = x.shape[0]
    M = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    #We need to create a clone of the model or else we make it unusable as part of the trickery 
    #to get pytorch to do what we want. Unforutantely, this exlcludes super big models. but, eh.
    model_clone = copy.deepcopy(model)
    
    params, names = extract_weights(model_clone)
    def model_ntk(*args,model=model_clone, names=names):
        params = tuple(args)
        load_weights(model, names, params)
        return model(x)
    
    Js = torch.autograd.functional.jacobian(model_ntk, tuple(params), create_graph=False, vectorize=True)
    
    Js = list(Js)
    #Js = [element for tupl in Js for element in tupl]
    #collapse the tensors
    for i,tensor in enumerate(Js):
        Js[i] = tensor.reshape(N,-1)
    
    J = torch.cat(Js,axis=1)
    
    if MODE=='minima':
        if N < M: #if datasize points is less than number of parameters:
            NTK = torch.matmul(J,J.T)

        if N >= M:#if number of parameters is less than datasize:
            NTK = torch.matmul(J.T,J)
    elif MODE=='samples':
        NTK = torch.matmul(J,J.T)
    elif MODE=='params':
        NTK = torch.matmul(J.T,J)
    
    return NTK

In [4]:
def relu(X,normalize=False):
    X = F.relu(X)
    if normalize:
        return np.sqrt(2*np.pi/(np.pi-1))*(X-1/np.sqrt(2*np.pi))
    else:
        return X
    

#Identity
def activation(x):
    return x

@njit
def d_activation(x):
    #return 1 
    return np.ones(np.shape(x),dtype=np.float32) #this should be a differnt value...?
    #return np.eye(np.shape(x)[0],np.shape(x)[1])

#Tanh
# def activation(x):
#     return torch.tanh(x)

#@njit
# def d_activation(x):
#     return np.cosh(x)**-2

In [5]:
input_features = 5
hidden_layer = 128
N_datapoints = 100

In [6]:
SEED = 1

In [195]:
@njit
def cross(X):
    return np.dot(np.transpose(X),X)

#@njit
def compute_NTK(Ws, Xs, d_int, d_array):#L counts from 1 to number of layers.
    '''
    I should add some docstring
    
    Ws, a list of the weights as np.array type np.float32,                          [W1, W2, W3 ... W]
    Xs, a list of the conjugate kernels as np.array type np.float32,            [X0, X1, X2, ... XL]
    d_int, a list of the dimensionality of X_l as int64,                        [d0, d1, d2, ... dL]
    d_array, a list of the dimensionality of X_l, as np.array type np.float 32, [d0, d1, d2, ... dL] 
    all of this is neccessary because numba doesnt like type conversion.
    
    outputs the NTK as a np.array of type np.float32
    '''
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
    n = Xs[0].shape[1] #number of datapoints
    Ds = [np.array([[0.0]],dtype=np.float32)] #holds the derivatives, first value is empty list...?; just a spacer, replace with array
    for l in range(L):
        Ds.append(d_activation(np.dot(Ws[l],Xs[l])))
    KNTK = cross(Xs[L]) #this is eventually summed over
    print(L+1,KNTK)
    for l in range(1,L+1):
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        XtX = cross(Xs[l-1])
        S = np.zeros((d_int[l],n),dtype=np.float32)
        for i in range(n):
            s = Ws[-1].T.reshape(-1)/np.sqrt(d_array[L])
            for k in range(L,l-1,-1):
                s = Ds[k][:,i]*s
                if k > l:
                    s = np.dot(Ws[k-1],s)/np.sqrt(d_array[k-1])
            S[:,i] = s
        print(l,cross(S)*XtX)
        KNTK += cross(S) * XtX
    return KNTK

In [196]:
def NTK_weights(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]
    if isinstance(m, nn.Conv2d):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]

In [197]:
#Layerwise Needs each conjugate Kernel
class dumb_small_layerwise(torch.nn.Module):
    '''
    simple network for test cases
    '''
    def __init__(self,):
        super(dumb_small_layerwise, self).__init__()
        
        self.d1 = torch.nn.Linear(3,2,bias=False)
        self.d2 = torch.nn.Linear(2,2,bias=False)
        self.d3 = torch.nn.Linear(2,1,bias=False)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0)) / np.sqrt(2)
        x_2 = activation(self.d2(x_1)) / np.sqrt(2)
        x_3 = activation(self.d3(x_2))
        
        return x_3, x_2, x_1, x_0
    
# Easy NTK expects one output alone
# class dumb_small(torch.nn.Module):
#     '''
#     simple network for test cases
#     '''
#     def __init__(self,):
#         super(dumb_small, self).__init__()
        
#         self.d1 = torch.nn.Linear(3,2,bias=False)
#         self.d2 = torch.nn.Linear(2,1,bias=False)
        
#     def forward(self, x_0):
#         x_1 = (self.d1(x_0)) / np.sqrt(2)
#         x_2 = self.d2(x_1) 
        
#         return x_2

In [198]:
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cpu'

model_small = dumb_small()
model_small.to(device)
model_small.apply(NTK_weights)

#Reset the seed and 
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)


model_layerwise = dumb_small_layerwise()
model_layerwise.to(device)
model_layerwise.apply(NTK_weights)

x_test = np.random.normal(0,1,(2,3)).astype(np.float32)
x_test = torch.from_numpy(x_test)



torch.Size([2, 3])
torch.Size([1, 2])
torch.Size([2, 3])
torch.Size([2, 2])
torch.Size([1, 2])


# Now compute 3 ways:

#### by hand

In [199]:
# W1 = model_small.d1.weight.detach().numpy().T
# W2 = model_small.d2.weight.detach().numpy().T

# print(W1.shape)

# A00 = W1[0,0] 
# A01 = W1[0,1]
# A10 = W1[1,0] 
# A11 = W1[1,1]
# A20 = W1[2,0]
# A21 = W1[2,1]

# B00 = W2[0,0]
# B10 = W2[1,0]

# X00 = x_test[0,0]
# X01 = x_test[0,1]
# X02 = x_test[0,2]
# X10 = x_test[1,0]
# X11 = x_test[1,1]
# X12 = x_test[1,2]

# J = np.array([[X00*B00, X01*B00, X02*B00, X00*B10, X01*B10, X02*B10, X00*A00 + X01*A10 + X02*A20, X00*A01 + X01*A11 + X02*A21],
#               [X10*B00, X11*B00, X12*B00, X10*B10, X11*B10, X12*B10, X10*A00 + X11*A10 + X12*A20, X10*A01 + X11*A11 + X12*A21]])

# J = J / np.sqrt(2) #in this 2 layer linear network, this is okay.

# NTK_byhand = (J @ J.T) 

#### Layerwise

In [200]:
x_3, x_2, x_1, x_0 = model_layerwise(x_test)

Ws = []
Ws.append(model_layerwise.d1.weight.detach().numpy().astype(np.float32))
Ws.append(model_layerwise.d2.weight.detach().numpy().astype(np.float32))
Ws.append(model_layerwise.d3.weight.detach().numpy().astype(np.float32))

Xs = [] # Xs are shape (output x #DP) ; however, typical python notation is reversed, so we take transpose here
Xs.append(x_0.detach().numpy().T.astype(np.float32))
Xs.append(x_1.detach().numpy().T.astype(np.float32))
Xs.append(x_2.detach().numpy().T.astype(np.float32))

ds_int = []
ds_int.append(3)
ds_int.append(2)
ds_int.append(2)

ds_array = []
ds_array.append(np.array([3.0],dtype=np.float32)) #first element is the input length
ds_array.append(np.array([2.0],dtype=np.float32))
ds_array.append(np.array([2.0],dtype=np.float32)) #the remaining elements are the output lengths, but omit the last output length assumed 1.

NTK_layerwise = compute_NTK(Ws, Xs, ds_int, ds_array)

3 [[ 0.3927115  -0.6094302 ]
 [-0.6094302   0.99327856]]
1 [[ 0.35391566 -0.11361163]
 [-0.11361163  0.7738312 ]]
2 [[ 0.52566415 -1.0252295 ]
 [-1.0252295   2.053475  ]]


#### Jacobian autograd

In [143]:
#NTK_easy = calculate_NTK(model_small,x_test)

# Compare

In [144]:
#print(NTK_byhand)

In [201]:
print(NTK_layerwise)

[[ 1.2722913 -1.7482712]
 [-1.7482712  3.8205848]]


In [146]:
#print(NTK_easy)

# Can we get pytorch to give us the matrices with the correct kinds of derivative?

In [147]:
model_layerwise.zero_grad()

In [148]:
#grad can only be called on a single element at a time. 
#I think we want to construct the s matrices which are done one column / one x-element at a time
#then we want to multiply them in by the conjugate kernels?
one_element = x_3[0]

In [149]:
one_element.backward()

In [150]:
w3_grad = model_layerwise.d3.weight.grad.detach().numpy()
w2_grad = model_layerwise.d2.weight.grad.detach().numpy()
w1_grad = model_layerwise.d1.weight.grad.detach().numpy()

In [151]:
w3 = model_layerwise.d3.weight.detach().numpy()
w2 = model_layerwise.d2.weight.detach().numpy()
w1 = model_layerwise.d1.weight.detach().numpy()

# Find the CK:

In [152]:
x_2[0] #nice, iterating over this gives the first component of the NTK, the rows of the conjugate kernel

tensor([ 0.3747, -0.5023], grad_fn=<SelectBackward>)

In [153]:
w3_grad #however, we can also just grab them from here

array([[ 0.37473935, -0.5022767 ]], dtype=float32)

# Find the next terms

In [None]:
# so you can either call backward, a number of times by 

In [174]:
w3_grad.reshape(-1).T @ w3_grad.reshape(-1) #matches exactly, conjugate kernel x conjugate kernel.

0.3927115

In [175]:
w2_grad.reshape(-1).T @ w2_grad.reshape(-1) #matches the third component exactly; where is the second?

0.5256641

In [188]:
w1_grad.reshape(-1,n).T.shape

(6,)

In [176]:
w1_grad.reshape(-1).T @ w1_grad.reshape(-1) #this ought to be 0.35391566, what gives.

0.074721396

In [157]:
w3_grad.reshape(-1).T @ w3_grad.reshape(-1) +\
w2_grad.reshape(-1).T @ w2_grad.reshape(-1) +\
w1_grad.reshape(-1).T @ w1_grad.reshape(-1)

0.993097

# verify hypothesis by increasing number of terms...

### also lets us see how this fares on more than one datapoint... it'll be alot of calls to backward...

### https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html

### note to future self, since these are partials pretty sure we can use inputs = w1 and that will make it so the gradients are computed alittle more efficiently

### and if inputs does work like that, then the reallly best thing would be to find a way to get vector gradients working.


### also note pytorch doesn't support autograd on non-scalar values, which means we are forced to call on backwards multiple times

In [220]:
#in the future we would iterate over layers instead of like this...
layer_components_w1 = [] #and we can make these arrays since we will know their size and slice into them
# and we would only do this one array element at a time.
layer_components_w2 = []
layer_components_w3 = []
for output in x_3:
    model_layerwise.zero_grad()
    output.backward(retain_graph=True)
    print(output)

    w3_grad = model_layerwise.d3.weight.grad.detach().numpy()
    print(w3_grad)
    w2_grad = model_layerwise.d2.weight.grad.detach().numpy()
    w1_grad = model_layerwise.d1.weight.grad.detach().numpy()

    layer_components_w1.append(w1_grad.reshape(-1).copy())
    layer_components_w2.append(w2_grad.reshape(-1).copy())
    layer_components_w3.append(w3_grad.reshape(-1).copy())

tensor([0.0987], grad_fn=<UnbindBackward>)
[[ 0.37473935 -0.5022767 ]]
tensor([-0.2933], grad_fn=<UnbindBackward>)
[[-0.4067955   0.90983295]]


In [221]:
layer_components_w1 = np.array(layer_components_w1)
layer_components_w2 = np.array(layer_components_w2)
layer_components_w3 = np.array(layer_components_w3)

In [222]:
layer_components_w3

array([[ 0.37473935, -0.5022767 ],
       [-0.4067955 ,  0.90983295]], dtype=float32)

In [223]:
layer_components_w3

array([[ 0.37473935, -0.5022767 ],
       [-0.4067955 ,  0.90983295]], dtype=float32)

In [224]:
layer_components_w1 @ layer_components_w1.T

array([[ 0.0747214 , -0.02398656],
       [-0.02398656,  0.1633772 ]], dtype=float32)

In [225]:
layer_components_w2 @ layer_components_w2.T

array([[ 0.5256641, -1.0252293],
       [-1.0252293,  2.0534747]], dtype=float32)

In [226]:
layer_components_w3 @ layer_components_w3.T

array([[ 0.3927115 , -0.6094302 ],
       [-0.6094302 ,  0.99327856]], dtype=float32)

In [None]:
3 [[ 0.3927115  -0.6094302 ] #good
 [-0.6094302   0.99327856]]

1 [[ 0.35391566 -0.11361163] #wrong. very confusing.
 [-0.11361163  0.7738312 ]]

2 [[ 0.52566415 -1.0252295 ] #pretty good
 [-1.0252295   2.053475  ]]