# This code checks the validity of my layerwise Jacobian method for a network whose layers have biases. We take the 'easy' method as the reference answer, since it just computes the full Jacobian directly with autograd.

In [1]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt


import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

#from ..easy_ntk import calculate_NTK
#from einops import rearrange

import time

import sys
from pathlib import Path

from numba import njit
from numba.typed import List

%matplotlib inline
%load_ext line_profiler
%load_ext memory_profiler

In [2]:
import copy

def _del_nested_attr(obj, names):
    """
    Remove the attribute reached by following `names` from `obj`.

    Every entry except the last is resolved with getattr to walk down to
    the owning object; the last entry is the attribute that gets deleted.
    For example, _del_nested_attr(obj, ['conv', 'weight']) deletes
    obj.conv.weight.
    """
    owner = obj
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])

def _set_nested_attr(obj, names, value):
    """
    Assign `value` to the attribute reached by following `names` from `obj`.

    Every entry except the last is resolved with getattr to walk down to
    the owning object; the last entry is the attribute that gets set.
    For example, _set_nested_attr(obj, ['conv', 'weight'], value) performs
    obj.conv.weight = value.
    """
    owner = obj
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, names[-1], value)

def extract_weights(mod):
    """
    Strip every Parameter off `mod` and hand them back as plain tensors.

    Returns a pair (params, names):
      * params - tuple of detached copies of the original parameters, each
        with requires_grad switched on (regular Tensors, not nn.Parameter).
      * names  - the dotted attribute names reported by named_parameters().

    `mod` is modified in place: the parameter attributes are deleted, so
    mod.parameters() is empty afterwards and the module cannot run a forward
    pass until `load_weights` puts tensors back.
    """
    # Snapshot both views before anything is deleted.
    original = tuple(mod.parameters())
    names = [name for name, _ in list(mod.named_parameters())]

    for dotted in names:
        _del_nested_attr(mod, dotted.split("."))

    params = tuple(tensor.detach().requires_grad_() for tensor in original)
    return params, names

def load_weights(mod, names, params):
    """
    Attach `params` to `mod` under the dotted attribute `names`, pairwise.

    The tensors are set as ordinary attributes (they may carry autograd
    history), not as nn.Parameter, so mod.parameters() stays empty after this
    call even though `mod` can run a forward pass again.
    """
    for dotted, tensor in zip(names, params):
        path = dotted.split(".")
        _set_nested_attr(mod, path, tensor)
        
def calculate_NTK(model,x,device='cpu',MODE='samples'):
    """
    Compute the neural tangent kernel (NTK) of `model` on the batch `x`.

    The Jacobian J of the model output with respect to every parameter of
    the model is computed with torch.autograd.functional.jacobian, each
    per-parameter block is flattened to N rows (N = x.shape[0]) and the
    blocks are concatenated column-wise. The NTK is a Gram matrix of J.

    INPUTS:
        model: torch.nn.Module. It is deep-copied internally, so the model
               passed in stays usable; this makes the call expensive for
               very large models.
        x: torch.Tensor, batch along dim 0. The caller's tensor is not
           modified.
        device: device `x` is moved to (default 'cpu'). The model is not
                moved here.
        MODE: which Gram matrix to return, with J of shape (N, M):
              'samples' (default) -> J @ J.T   (N x N)
              'params'            -> J.T @ J   (M x M)
              'minima'            -> J @ J.T if N < M, else J.T @ J

    OUTPUTS:
        NTK: torch.Tensor

    RAISES:
        ValueError: if MODE is not one of the three values above.

    NOTE(review): each Jacobian block is reshaped to (N, -1), which presumably
    assumes one scalar output per sample; with several outputs the output axis
    is folded into the parameter axis - confirm before using on such models.

    #EXAMPLE USAGE:
    device='cpu'
    model = MODEL() #a torch.nn.Module object
    model.to(device)

    x_test = np.ones((100,1,28,28),dtype=np.float32)
    x_test = torch.from_numpy(x_test)

    NTK = calculate_NTK(model,x_test)
    """
    if MODE not in ('minima','samples','params'):
        raise ValueError("MODE must be one of 'minima','samples','params'")

    # detach() gives a view that does not require grad. Assigning
    # x.requires_grad = False instead would flip the flag on the caller's own
    # tensor and raises for non-leaf tensors.
    x = x.detach().to(device)
    N = x.shape[0]

    # Work on a clone: extract_weights strips the parameters off the module it
    # is given, which would leave the caller's model unusable.
    model_clone = copy.deepcopy(model)
    params, names = extract_weights(model_clone)

    def model_ntk(*args,model=model_clone, names=names):
        # Re-attach the candidate parameters so autograd can trace the output
        # back to them.
        load_weights(model, names, args)
        return model(x)

    Js = torch.autograd.functional.jacobian(model_ntk, tuple(params), create_graph=False, vectorize=True)

    # One block per parameter tensor, each flattened to N rows.
    J = torch.cat([block.reshape(N,-1) for block in Js], dim=1)

    # Take the parameter count from J itself: the Jacobian covers every
    # parameter returned by extract_weights, whether or not it was flagged
    # requires_grad on the original model.
    M = J.shape[1]

    if MODE=='minima':
        if N < M: #fewer data points than parameters
            NTK = torch.matmul(J,J.T)
        else: #at least as many data points as parameters
            NTK = torch.matmul(J.T,J)
    elif MODE=='samples':
        NTK = torch.matmul(J,J.T)
    else: #MODE=='params'
        NTK = torch.matmul(J.T,J)

    return NTK

In [3]:
def relu(X,normalize=False):
    """
    Apply ReLU to `X`, optionally standardising the result.

    With normalize=True the rectified values are shifted by 1/sqrt(2*pi) and
    scaled by sqrt(2*pi/(pi-1)); these are the mean and the reciprocal
    standard deviation of ReLU applied to a standard normal variable.
    """
    rectified = F.relu(X)
    if not normalize:
        return rectified

    shift = 1/np.sqrt(2*np.pi)
    scale = np.sqrt(2*np.pi/(np.pi-1))
    return scale*(rectified-shift)
    

# #Identity
def activation(x):
    """Identity activation: return the input unchanged."""
    return x

@njit
def d_activation(x):
    """Derivative of the identity activation: a float32 array of ones with the shape of `x`."""
    return np.ones(np.shape(x),dtype=np.float32) 


# #Tanh
# def activation(x):
#     return torch.tanh(x)

# @njit
# def d_activation(x):
#     return np.cosh(x)**-2

In [4]:
# Seed passed to the torch, random and numpy RNGs before each model is built.
SEED = 2

In [5]:
def NTK_weights(m):
    """
    Re-initialise a Linear or Conv2d layer with standard-normal parameters.

    Meant to be used as `model.apply(NTK_weights)`. For nn.Linear and
    nn.Conv2d modules the weight shape is printed, then the weight and (when
    the layer has one) the bias are overwritten in place with N(0, 1) draws.
    No width-dependent scaling is applied. Any other module type is left
    untouched.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    print(m.weight.shape)
    # nn.init functions run without gradient tracking, so the parameters can
    # be passed directly rather than through `.data`.
    nn.init.normal_(m.weight)
    # Layers built with bias=False expose bias as None.
    if m.bias is not None:
        nn.init.normal_(m.bias)

In [6]:
# Easy NTK expects one output alone
class dumb_small(torch.nn.Module):
    '''
    Tiny fully connected network for test cases: 2 -> 3 -> 3 -> 1.

    All three Linear layers are created with their default bias. The module
    level `activation` is applied after the first two layers; the last layer
    is a plain linear read-out, so forward returns a single tensor.
    '''
    def __init__(self):
        super().__init__()

        # Built in the order d1, d2, d3.
        self.d1 = torch.nn.Linear(2, 3)
        self.d2 = torch.nn.Linear(3, 3)
        self.d3 = torch.nn.Linear(3, 1)

    def forward(self, x_0):
        hidden = activation(self.d1(x_0))
        hidden = activation(self.d2(hidden))
        return self.d3(hidden)
    
# Layerwise variant: same layers as dumb_small, but forward also returns the intermediate activations and the input
class dumb_small_layerwise(torch.nn.Module):
    '''
    Tiny fully connected network for test cases: 2 -> 3 -> 3 -> 1.

    Same layers as `dumb_small`, but forward hands back every intermediate
    value as the tuple (x_3, x_2, x_1, x_0): the network output, the two
    hidden activations, and the input itself.
    '''
    def __init__(self):
        super().__init__()

        # Built in the order d1, d2, d3.
        self.d1 = torch.nn.Linear(2, 3)
        self.d2 = torch.nn.Linear(3, 3)
        self.d3 = torch.nn.Linear(3, 1)

    def forward(self, x_0):
        first_hidden = activation(self.d1(x_0))
        second_hidden = activation(self.d2(first_hidden))
        output = self.d3(second_hidden)
        return output, second_hidden, first_hidden, x_0

In [7]:
device='cpu'

def _reseed(seed=SEED):
    """Reset the torch, random and numpy RNGs to `seed`."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

# Reference model: reseed, build, then overwrite its parameters via NTK_weights.
_reseed()
model_small = dumb_small()
model_small.to(device)
model_small.apply(NTK_weights)

# Layerwise model: reseeded identically before it is built and initialised.
_reseed()
model_layerwise = dumb_small_layerwise()
model_layerwise.to(device)
model_layerwise.apply(NTK_weights)

# Three 2-dimensional test inputs, drawn from numpy right after the second reseed.
x_test = torch.from_numpy(np.random.normal(0,1,(3,2)).astype(np.float32))

torch.Size([3, 2])
torch.Size([3, 3])
torch.Size([1, 3])
torch.Size([3, 2])
torch.Size([3, 3])
torch.Size([1, 3])


In [8]:
# Both models were built from the same seed, so every parameter must match
# exactly before their NTKs can be compared. Biases are checked as well as
# weights, since the bias case is what this notebook is about.
for (name_small, p_small), (name_layer, p_layer) in zip(
        model_small.named_parameters(), model_layerwise.named_parameters()):
    assert name_small == name_layer
    assert torch.equal(p_small, p_layer), f"{name_small} differs between the two models"

#### Jacobian autograd

In [9]:
# Reference NTK from the full autograd Jacobian, converted to a numpy array
# so it can be compared with the numpy-based results below.
_ntk_easy_tensor = calculate_NTK(model_small, x_test)
NTK_easy = _ntk_easy_tensor.detach().numpy()

#### By Hand

In [10]:
# Hand-derived Jacobian of the network output with respect to the three
# weight matrices, written out entry by entry.
# A, B, C are the transposed weights of d1, d2, d3.
# NOTE(review): no bias value and no bias derivative appears in the
# expressions below - confirm this is intended when comparing against
# results that include biases.
A = model_small.d1.weight.detach().numpy().T
B = model_small.d2.weight.detach().numpy().T
C = model_small.d3.weight.detach().numpy().T

A00 = A[0,0]
A01 = A[0,1]
A02 = A[0,2]
A10 = A[1,0]
A11 = A[1,1]
A12 = A[1,2]

B00 = B[0,0]
B01 = B[0,1]
B02 = B[0,2]
B10 = B[1,0]
B11 = B[1,1]
B12 = B[1,2]
B20 = B[2,0]
B21 = B[2,1]
B22 = B[2,2]

C00 = C[0,0]
C10 = C[1,0]
C20 = C[2,0]

x_test_np = x_test.numpy()

# Xij is feature j of test sample i.
X00 = x_test_np[0,0]
X01 = x_test_np[0,1]
X10 = x_test_np[1,0]
X11 = x_test_np[1,1]
X20 = x_test_np[2,0]
X21 = x_test_np[2,1]

# One row per test sample; every row uses the same column layout as the first.
J = [[
    # columns 0-2: derivatives w.r.t. C00, C10, C20
    B00*(X00*A00 + X01*A10) + B10*(X00*A01+X01*A11) +B20*(X00*A02+X01*A12),
    B01*(X00*A00 + X01*A10) + B11*(X00*A01+X01*A11) +B21*(X00*A02+X01*A12),
    B02*(X00*A00 + X01*A10) + B12*(X00*A01+X01*A11) +B22*(X00*A02+X01*A12),
    # columns 3-11: derivatives w.r.t. B00, B01, B02, B10, ..., B22
    C00*(X00*A00+X01*A10),
    C10*(X00*A00+X01*A10),
    C20*(X00*A00+X01*A10),
    C00*(X00*A01+X01*A11),
    C10*(X00*A01+X01*A11),
    C20*(X00*A01+X01*A11),
    C00*(X00*A02+X01*A12),
    C10*(X00*A02+X01*A12),
    C20*(X00*A02+X01*A12),
    # columns 12-17: derivatives w.r.t. A00, A01, A02, A10, A11, A12
    C00*B00*X00 + C10*B01*X00 + C20*B02*X00,
    C00*B10*X00 + C10*B11*X00 + C20*B12*X00,
    C00*B20*X00 + C10*B21*X00 + C20*B22*X00,
    C00*B00*X01 + C10*B01*X01 + C20*B02*X01,
    C00*B10*X01 + C10*B11*X01 + C20*B12*X01,
    C00*B20*X01 + C10*B21*X01 + C20*B22*X01
    ],
    [
    B00*(X10*A00 + X11*A10) + B10*(X10*A01+X11*A11) +B20*(X10*A02+X11*A12),
    B01*(X10*A00 + X11*A10) + B11*(X10*A01+X11*A11) +B21*(X10*A02+X11*A12),
    B02*(X10*A00 + X11*A10) + B12*(X10*A01+X11*A11) +B22*(X10*A02+X11*A12),
    C00*(X10*A00+X11*A10),
    C10*(X10*A00+X11*A10),
    C20*(X10*A00+X11*A10),
    C00*(X10*A01+X11*A11),
    C10*(X10*A01+X11*A11),
    C20*(X10*A01+X11*A11),
    C00*(X10*A02+X11*A12),
    C10*(X10*A02+X11*A12),
    C20*(X10*A02+X11*A12),
    C00*B00*X10 + C10*B01*X10 + C20*B02*X10,
    C00*B10*X10 + C10*B11*X10 + C20*B12*X10,
    C00*B20*X10 + C10*B21*X10 + C20*B22*X10,
    C00*B00*X11 + C10*B01*X11 + C20*B02*X11,
    C00*B10*X11 + C10*B11*X11 + C20*B12*X11,
    C00*B20*X11 + C10*B21*X11 + C20*B22*X11
    ],
    [
    B00*(X20*A00 + X21*A10) + B10*(X20*A01+X21*A11) +B20*(X20*A02+X21*A12),
    B01*(X20*A00 + X21*A10) + B11*(X20*A01+X21*A11) +B21*(X20*A02+X21*A12),
    B02*(X20*A00 + X21*A10) + B12*(X20*A01+X21*A11) +B22*(X20*A02+X21*A12),
    C00*(X20*A00+X21*A10),
    C10*(X20*A00+X21*A10),
    C20*(X20*A00+X21*A10),
    C00*(X20*A01+X21*A11),
    C10*(X20*A01+X21*A11),
    C20*(X20*A01+X21*A11),
    C00*(X20*A02+X21*A12),
    C10*(X20*A02+X21*A12),
    C20*(X20*A02+X21*A12),
    C00*B00*X20 + C10*B01*X20 + C20*B02*X20,
    C00*B10*X20 + C10*B11*X20 + C20*B12*X20,
    C00*B20*X20 + C10*B21*X20 + C20*B22*X20,
    C00*B00*X21 + C10*B01*X21 + C20*B02*X21,
    C00*B10*X21 + C10*B11*X21 + C20*B12*X21,
    C00*B20*X21 + C10*B21*X21 + C20*B22*X21
    ]]

J = np.array(J)

# Gram matrix over samples built from the hand-derived Jacobian.
J @ J.T

array([[ 3.973685, 15.353924, 18.65885 ],
       [15.353924, 85.88938 , 63.837574],
       [18.65885 , 63.837574, 90.18205 ]], dtype=float32)

In [11]:
# The same hand-derived expressions as the full Jacobian J, split by weight
# matrix so each layer's Gram matrix can be inspected separately.
# One row per test sample in every matrix.

# Derivatives w.r.t. C00, C10, C20.
JC = [[
    B00*(X00*A00 + X01*A10) + B10*(X00*A01+X01*A11) +B20*(X00*A02+X01*A12),
    B01*(X00*A00 + X01*A10) + B11*(X00*A01+X01*A11) +B21*(X00*A02+X01*A12),
    B02*(X00*A00 + X01*A10) + B12*(X00*A01+X01*A11) +B22*(X00*A02+X01*A12)
    ],
    [
    B00*(X10*A00 + X11*A10) + B10*(X10*A01+X11*A11) +B20*(X10*A02+X11*A12),
    B01*(X10*A00 + X11*A10) + B11*(X10*A01+X11*A11) +B21*(X10*A02+X11*A12),
    B02*(X10*A00 + X11*A10) + B12*(X10*A01+X11*A11) +B22*(X10*A02+X11*A12)    
    ],
    [
    B00*(X20*A00 + X21*A10) + B10*(X20*A01+X21*A11) +B20*(X20*A02+X21*A12),
    B01*(X20*A00 + X21*A10) + B11*(X20*A01+X21*A11) +B21*(X20*A02+X21*A12),
    B02*(X20*A00 + X21*A10) + B12*(X20*A01+X21*A11) +B22*(X20*A02+X21*A12)    
    ]]

# Derivatives w.r.t. B00, B01, B02, B10, ..., B22.
JB = [[
    C00*(X00*A00+X01*A10),
    C10*(X00*A00+X01*A10),
    C20*(X00*A00+X01*A10),
    C00*(X00*A01+X01*A11),
    C10*(X00*A01+X01*A11),
    C20*(X00*A01+X01*A11),
    C00*(X00*A02+X01*A12),
    C10*(X00*A02+X01*A12),
    C20*(X00*A02+X01*A12)
    ],
    [
    C00*(X10*A00+X11*A10),
    C10*(X10*A00+X11*A10),
    C20*(X10*A00+X11*A10),
    C00*(X10*A01+X11*A11),
    C10*(X10*A01+X11*A11),
    C20*(X10*A01+X11*A11),
    C00*(X10*A02+X11*A12),
    C10*(X10*A02+X11*A12),
    C20*(X10*A02+X11*A12)    
    ],
    [
    C00*(X20*A00+X21*A10),
    C10*(X20*A00+X21*A10),
    C20*(X20*A00+X21*A10),
    C00*(X20*A01+X21*A11),
    C10*(X20*A01+X21*A11),
    C20*(X20*A01+X21*A11),
    C00*(X20*A02+X21*A12),
    C10*(X20*A02+X21*A12),
    C20*(X20*A02+X21*A12)   
    ]]

# Derivatives w.r.t. A00, A01, A02, A10, A11, A12.
JA = [[
    C00*B00*X00 + C10*B01*X00 + C20*B02*X00,
    C00*B10*X00 + C10*B11*X00 + C20*B12*X00,
    C00*B20*X00 + C10*B21*X00 + C20*B22*X00,
    C00*B00*X01 + C10*B01*X01 + C20*B02*X01,
    C00*B10*X01 + C10*B11*X01 + C20*B12*X01,
    C00*B20*X01 + C10*B21*X01 + C20*B22*X01
    ],
    [
    C00*B00*X10 + C10*B01*X10 + C20*B02*X10,
    C00*B10*X10 + C10*B11*X10 + C20*B12*X10,
    C00*B20*X10 + C10*B21*X10 + C20*B22*X10,
    C00*B00*X11 + C10*B01*X11 + C20*B02*X11,
    C00*B10*X11 + C10*B11*X11 + C20*B12*X11,
    C00*B20*X11 + C10*B21*X11 + C20*B22*X11    
    ],
    [
    C00*B00*X20 + C10*B01*X20 + C20*B02*X20,
    C00*B10*X20 + C10*B11*X20 + C20*B12*X20,
    C00*B20*X20 + C10*B21*X20 + C20*B22*X20,
    C00*B00*X21 + C10*B01*X21 + C20*B02*X21,
    C00*B10*X21 + C10*B11*X21 + C20*B12*X21,
    C00*B20*X21 + C10*B21*X21 + C20*B22*X21    
    ]]

JA = np.array(JA)

JB = np.array(JB)

JC = np.array(JC)

# Per-layer Gram matrices over samples.
print(JA @ JA.T)
print(JB @ JB.T)
print(JC @ JC.T)

[[ 1.1859366  5.351098   5.3296876]
 [ 5.351098  48.64249   16.432072 ]
 [ 5.3296876 16.432072  26.319826 ]]
[[ 0.9833354  3.3494906  4.7572618]
 [ 3.3494906 11.82553   16.075016 ]
 [ 4.7572618 16.075016  23.055311 ]]
[[ 1.8044131  6.6533346  8.571902 ]
 [ 6.6533346 25.421356  31.330482 ]
 [ 8.571902  31.330482  40.806915 ]]


# Can we get PyTorch to give us the per-layer gradient matrices directly from backward()?

In [12]:
# Clear any stored gradients, then run the forward pass whose outputs are
# back-propagated one at a time further down.
model_small.zero_grad()
x_4 = model_small(x_test)

In [13]:
#This method agrees between model_layerwise and model_small, meaning that the calculation is independent of which of
#those two models is used. The implication is that something is wrong with both of my methods for calculating the
#same thing, since they agree with one another.

#in the future we would iterate over layers instead of like this...
# One list per layer; each entry is the flattened weight gradient of a single
# network output. (A loop over all layers would replace the explicit triple.)
layer_components_w1, layer_components_w2, layer_components_w3 = [], [], []

_layer_buckets = (
    (model_small.d1, layer_components_w1),
    (model_small.d2, layer_components_w2),
    (model_small.d3, layer_components_w3),
)

for output in x_4:
    # Start every output from cleared gradients, and keep the graph alive so
    # the next output can be back-propagated through it too.
    model_small.zero_grad()
    output.backward(retain_graph=True)

    for layer, bucket in _layer_buckets:
        flat_grad = layer.weight.grad.detach().numpy().reshape(-1)
        # copy() so the stored row does not share memory with the .grad buffer
        bucket.append(flat_grad.copy())


In [14]:
# Turn each list of per-output gradient rows into a single 2-D array.
layer_components_w1, layer_components_w2, layer_components_w3 = [
    np.array(rows)
    for rows in (layer_components_w1, layer_components_w2, layer_components_w3)
]

In [15]:
# Sum of the per-layer Gram matrices of the weight gradients collected above.
_layer_grams = [
    grads @ grads.T
    for grads in (layer_components_w1, layer_components_w2, layer_components_w3)
]
autograd_NTK = _layer_grams[0] + _layer_grams[1] + _layer_grams[2]

# Layerwise exact

In [23]:
@njit #compiled with numba; no parallel transformation available
def cross(X):
    """Return X^T X, the Gram matrix over the columns of `X`."""
    return X.T.dot(X)

def compute_NTK_new(Ws, Xs, d_int, d_array):
    """
    Assemble the NTK layer by layer from weights and stored activations.

    Only weight matrices enter the computation: pre-activations are formed
    as np.dot(Ws[l], Xs[l]) with no bias added, and no bias term is produced.

    Ws: list of weight matrices as np.float32 arrays; Ws[-1] is the read-out
        layer.
    Xs: list of activations [X0, X1, ..., XL] as np.float32 arrays, one
        column per data point; X0 is the network input.
    d_int: not used; kept so existing callers do not break.
    d_array: per-layer normalisation constants; the back-propagated factor is
        divided by sqrt(d_array[L]) at the read-out and by sqrt(d_array[k-1])
        each time it is pushed through Ws[k-1].

    Returns (KNTK, components): components[0] is cross(Xs[L]), components[l]
    for l = 1..L is the term (S^T S) * (X_{l-1}^T X_{l-1}) of layer l, and
    KNTK is the sum of all components.
    """
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs

    #Activation derivatives per layer; index 0 is a spacer so Ds[k] belongs to layer k
    Ds = [np.array([[0.0]],dtype=np.float32)]
    for l in range(L):
        Ds.append(d_activation(np.dot(Ws[l],Xs[l])))

    #The first term is just the conjugate kernel. KNTK gets its own copy
    #because it is accumulated in place below.
    conjugate_kernel = cross(Xs[L])
    components = [conjugate_kernel]
    KNTK = conjugate_kernel.copy()

    #Read-out weights as a column; the same starting point for every layer
    S_readout = np.expand_dims(Ws[-1].T.reshape(-1)/np.sqrt(d_array[L]),axis=1)

    for l in range(1,L+1): #l counts layers going forward from 1
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        XtX = cross(Xs[l-1])
        S = S_readout
        for k in range(L,l-1,-1): #counts backwards from L down to l, inclusive
            S = Ds[k]*S
            if k > l:
                #push S one layer further back through Ws[k-1]
                S = np.dot(S.T,Ws[k-1]).T/np.sqrt(d_array[k-1])
        term = cross(S) * XtX
        components.append(term)
        KNTK += term
    return KNTK, components

# def compute_NTK(Ws, Xs, d_int, d_array):#L counts from 1 to number of layers.
#     '''
#     I should add some docstring
    
#     Ws, a list of the weights as np.array type np.float32,                          [W1, W2, W3 ... W]
#     Xs, a list of the conjugate kernels as np.array type np.float32,            [X0, X1, X2, ... XL]
#     d_int, a list of the dimensionality of X_l as int64,                        [d0, d1, d2, ... dL]
#     d_array, a list of the dimensionality of X_l, as np.array type np.float 32, [d0, d1, d2, ... dL] 
#     all of this is neccessary because numba doesnt like type conversion.
    
#     outputs the NTK as a np.array of type np.float32
#     '''
#     L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
#     components= []
#     n = Xs[0].shape[1] #number of datapoints
    
#     #holds the derivatives, first value is empty list...?; just a spacer, replace with array
#     Ds = [np.array([[0.0]],dtype=np.float32)] 
#     for l in range(L):
#         Ds.append(d_activation(np.dot(Ws[l],Xs[l])))
    
#     #The first term is just conjugate kernel
#     KNTK = cross(Xs[L])
#     components.append(cross(Xs[L]))
#     for l in range(1,L+1):#l counts layers going forward from 1...
#         #we are going to construct terms that look like ( S^T S ) * (X^T X)
#         XtX = cross(Xs[l-1])
#         S = np.zeros((d_int[l],n),dtype=np.float32)
#         for i in range(n):
            
#             #this is always the same, could be calculated once and saved
#             s = Ws[-1].T.reshape(-1)/np.sqrt(d_array[L]) #has shape input to last layer.

#             for k in range(L,l-1,-1): #counts backwards from L to l, inclusive both
#                 s = np.diag(Ds[k][:,i]) @ s
#                 if k > l:

#                     s = np.dot(s,Ws[k-1])/np.sqrt(d_array[k-1])

#             S[:,i] = s
        
#         components.append(cross(S) * XtX)
#         KNTK += cross(S) * XtX
#     return KNTK, components

In [24]:
x_3, x_2, x_1, x_0 = model_layerwise(x_test)

_layers = (model_layerwise.d1, model_layerwise.d2, model_layerwise.d3)

# Weights and biases of each layer as float32 numpy arrays.
Ws = [layer.weight.detach().numpy().astype(np.float32) for layer in _layers]
Bs = [layer.bias.detach().numpy().astype(np.float32) for layer in _layers]

# Xs are shape (output x #DP); torch keeps the data points along the first
# axis, so each activation is transposed here.
Xs = [act.detach().numpy().T.astype(np.float32) for act in (x_0, x_1, x_2)]

ds_int = [0, 3, 3]

# Normalisation constants handed to the layerwise routine, one single-element
# float32 array per entry. The last output length is omitted, assumed 1.
ds_array = [np.array([value],dtype=np.float32) for value in (0.0, 1.0, 1.0, 1.0)]

In [25]:
# NOTE(review): compute_NTK_new_WithBias is not defined in the part of the
# notebook visible here (only compute_NTK_new, which takes no Bs argument) -
# confirm it is defined before this cell is run.
layerwise_NTK, components = compute_NTK_new_WithBias(Ws, Bs, Xs, ds_int, ds_array)

# Compare all against each other

In [26]:
# NTK returned by the layerwise routine.
layerwise_NTK

array([[160.76213, 193.00449, 200.97227],
       [193.00449, 284.4021 , 267.01312],
       [200.97227, 267.01312, 298.02045]], dtype=float32)

In [27]:
# Reference NTK from calculate_NTK.
NTK_easy

array([[ 54.521683,  86.76405 ,  94.73182 ],
       [ 86.76405 , 178.16164 , 160.77267 ],
       [ 94.73183 , 160.77267 , 191.78003 ]], dtype=float32)

In [28]:
# NTK summed from the per-layer weight gradients obtained with backward().
autograd_NTK

array([[ 45.40515 ,  77.647514,  85.615295],
       [ 77.647514, 169.0451  , 151.65613 ],
       [ 85.615295, 151.65613 , 182.66347 ]], dtype=float32)

In [29]:
# Gram matrix of the hand-derived Jacobian.
J @ J.T

array([[ 3.973685, 15.353924, 18.65885 ],
       [15.353924, 85.88938 , 63.837574],
       [18.65885 , 63.837574, 90.18205 ]], dtype=float32)

# Compare components against one another

In [59]:
# Per-layer Gram matrices of the backward() weight gradients: d1, d2, d3.
for grads in (layer_components_w1, layer_components_w2, layer_components_w3):
    print(grads @ grads.T)


[[  2.875438  12.974344  12.922433]
 [ 12.974344 117.93925   39.841427]
 [ 12.922433  39.841427  63.815407]]
[[  0.74225026   2.4397616    3.6184444 ]
 [  2.4397616   79.9333     -10.463793  ]
 [  3.6184444  -10.463793    24.590609  ]]
[[  0.5682735  -3.9876812   4.590777 ]
 [ -3.9876812  44.74071   -37.424416 ]
 [  4.590777  -37.424416   38.706203 ]]


In [60]:
# Layerwise components printed in the order 1, 2, 0, so that the first-layer
# term comes first and components[0] comes last.
for index in (1, 2, 0):
    print(components[index])

[[  2.8754382  12.974345   12.922433 ]
 [ 12.974345  117.939255   39.841427 ]
 [ 12.922433   39.841427   63.815407 ]]
[[  0.74225026   2.4397616    3.6184444 ]
 [  2.4397616   79.933304   -10.463793  ]
 [  3.6184444  -10.463793    24.590607  ]]
[[  0.5682735  -3.9876812   4.590777 ]
 [ -3.9876812  44.74071   -37.424416 ]
 [  4.590777  -37.424416   38.706203 ]]


In [61]:
# Per-layer Gram matrices of the hand-derived Jacobian blocks: JA, JB, JC.
for block_jacobian in (JA, JB, JC):
    print(block_jacobian @ block_jacobian.T)

[[  2.8754382  12.974345   12.922433 ]
 [ 12.974345  117.93924    39.841423 ]
 [ 12.922433   39.841423   63.8154   ]]
[[  0.7422503   2.4397616   3.6184447]
 [  2.4397616  79.9333    -10.463793 ]
 [  3.6184447 -10.463793   24.590609 ]]
[[  0.5682735  -3.9876812   4.590777 ]
 [ -3.9876812  44.74071   -37.424416 ]
 [  4.590777  -37.424416   38.706203 ]]


In [62]:
#Off by a scalar? That scalar appeared to depend on the seed, i.e. on some property of the weights.
#Element-wise ratio of the layerwise first-layer component to the backward() one; a constant matrix means a pure scalar mismatch.
np.array(components[1]) / (layer_components_w1 @ layer_components_w1.T)

array([[1.0000001, 1.0000001, 1.       ],
       [1.0000001, 1.0000001, 1.       ],
       [1.       , 1.       , 1.       ]], dtype=float32)

# Conclusions: I am very sure that NTK_easy and the layerwise autograd result are consistent with something done by hand.

# The layerwise calculation's components are off by a scalar. If we vary the number of nodes in a layer and the scalar changes, then we might be able to figure out what is going wrong. Or is it dependent on some property of the weight matrix?