# The point of our library is to eventually reach something that looks like:

#### NTK = pytorch.Model.NTK(X)

#### So forget about using only NumPy to specify all these computations; let's use PyTorch operations themselves. Then we can try PyTorch CUDA tensors instead of doing the computation on the CPU, which should give our most dramatic speedup.

#### Then we can work on implementing a specific subset of the functions in Numba. Padding isn't currently implemented in Numba (JAX does provide `jnp.pad`). Numba is also limited in certain annoying ways: it can only reshape contiguous arrays, and all the transposes I've been using seem to return views of non-contiguous arrays. `np.ascontiguousarray()` doesn't seem to override the layout of the transpose's output, and neither does copying the array.

#### Ultimately, the fastest implementation will likely be in JAX, but the entire point of this is that I want it in native PyTorch.

In [1]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt


import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

#from ..easy_ntk import calculate_NTK
from einops import rearrange

import time

#import sys
#from pathlib import Path

#from numba import njit
#from numba.typed import List

from numba import njit

%matplotlib inline
%load_ext line_profiler
%load_ext memory_profiler

In [67]:
# Scratch check: splitting H and W into (H//Kh, Kh) blocks only works when Kh divides H.
# Here H = W = 5 and Kh = Kw = 3, so the target shape (1,1,1,3,1,3) holds 9 elements
# instead of 25 and numpy raises the ValueError shown below.
test = np.array([[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5]])
test = test[None,None,:,:]
N, C, H, W = test.shape
Kh, Kw, = (3, 3)

test.reshape(N, C, H//Kh, Kh, W//Kw, Kw)

ValueError: cannot reshape array of size 25 into shape (1,1,1,3,1,3)

In [97]:
#Numpy CNN function-- checked that it agrees with Tensorflow.
#need to make sure it agrees with PyTorch.


class Conv2d():
    '''
    NumPy reference implementation of a 2-D convolution (cross-correlation,
    no bias) on NCHW arrays, computed with the im2col approach.
    '''
    def __init__(self,F,strides=1,padding=0):
        '''
        Parameters
        ----------
        F : np.ndarray
            Filter bank of shape [channels_out, channels_in, kernel_height,
            kernel_width], i.e. the layout of ``torch.nn.Conv2d.weight``.
        strides : int
            Stride used along both spatial axes; must be >= 1.
        padding : int
            Number of zero rows/columns added on every spatial side; must be >= 0.

        Raises
        ------
        ValueError
            If ``padding`` is negative or ``strides`` is smaller than 1.

        Algorithm (see ``forward``):
          1. Flatten the filter to a 2-D matrix of shape
             [kernel_height * kernel_width * channels_in, channels_out].
          2. Extract image patches into an array of shape
             [batch, out_height, out_width, kernel_height * kernel_width * channels_in].
          3. Right-multiply each patch vector by the filter matrix.
        The bias is not applied.
        '''
        if padding < 0:
            raise ValueError('Padding must be a non-negative int')
        if strides < 1:
            raise ValueError('Strides must be a positive int')

        self.out_filters = F.shape[0]
        self.in_filters = F.shape[1]
        self.kernel_height = F.shape[2]
        self.kernel_width = F.shape[3]
        self.strides = strides
        self.padding = padding

        # F.T reverses the axes to [kernel_width, kernel_height, channels_in, channels_out];
        # the C-order reshape then yields rows ordered with the input channel fastest,
        # then kernel row, then kernel column, and one column per output channel.
        self.F = np.reshape(F.T, (-1, self.out_filters))

    def forward(self,x):
        '''
        Convolve ``x`` of shape [batch, channels_in, height, width] with the stored filters.

        Returns an array of shape [batch, channels_out, new_height, new_width], where
        new_size = (size + 2*padding - kernel_size) // strides + 1.

        Raises ValueError if the kernel is larger than the padded input.
        '''
        batches, channels_in, height, width = np.shape(x)
        kh, kw = self.kernel_height, self.kernel_width
        s, p = self.strides, self.padding

        if height + 2*p < kh or width + 2*p < kw:
            raise ValueError('Kernel is larger than the padded input')

        if p != 0:
            x = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant', constant_values=0.0)

        # Exact integer arithmetic rather than int() of a float division.
        new_height = (height + 2*p - kh)//s + 1
        new_width = (width + 2*p - kw)//s + 1

        patch_size = kh * kw * channels_in
        patches = np.zeros((batches, new_height, new_width, patch_size), dtype=np.float32)
        for i in range(new_height):
            for j in range(new_width):
                window = x[:, :, s*i:s*i + kh, s*j:s*j + kw]
                # Fortran order makes the channel index vary fastest, then kernel row,
                # then kernel column -- the same row order as self.F.
                patches[:, i, j, :] = np.reshape(window, (batches, patch_size), order='F')

        output_array = patches @ self.F  # [batch, new_height, new_width, channels_out]
        # [batch, h, w, channels_out] -> [batch, channels_out, h, w]
        return np.transpose(output_array, (0, 3, 1, 2))

In [98]:
import copy

def _del_nested_attr(obj, names):
    """
    Deletes the attribute specified by the given list of names.
    For example, to delete the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'])
    """
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj, names, value):
    """
    Set the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _del_nested_attr(obj, ['conv', 'weight'], value)
    """
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)

def extract_weights(mod):
    """
    Strip every Parameter out of ``mod`` and return detached copies of them.

    Returns
    -------
    params : tuple of torch.Tensor
        Plain tensors (not ``nn.Parameter``) detached from the originals and with
        ``requires_grad`` switched on, in ``mod.named_parameters()`` order.
    names : list of str
        Dotted attribute names, position-matched with ``params``.

    ``mod`` is modified in place: afterwards ``mod.parameters()`` is empty, and the
    module cannot run a forward pass until ``load_weights`` puts tensors back.
    """
    named = list(mod.named_parameters())
    names = [dotted for dotted, _ in named]
    for dotted in names:
        _del_nested_attr(mod, dotted.split("."))
    params = tuple(tensor.detach().requires_grad_() for _, tensor in named)
    return params, names

def load_weights(mod, names, params):
    """
    Attach ``params`` to ``mod`` under the dotted attribute paths in ``names``.

    The tensors are set as plain attributes (not re-registered as Parameters), so
    they keep any autograd history and ``mod.parameters()`` stays empty afterwards.
    This is what lets a forward pass be differentiated with respect to ``params``.
    """
    for dotted, tensor in zip(names, params):
        path = dotted.split(".")
        _set_nested_attr(mod, path, tensor)
        
def calculate_NTK(model,x,device='cpu',MODE='samples'):
    """
    Compute the empirical Neural Tangent Kernel of ``model`` on the batch ``x``.

    The NTK is the Gram matrix of the Jacobian J of the model output with respect to
    all of the model's parameters. J has one row per sample (axis 0 of ``x``). Its
    columns run over every (output element, parameter) pair, so for a single-output
    model J is N x M, where M is the total number of parameter entries.

    INPUTS:
        model: torch.nn.Module. It is deep-copied, so the caller's model is not modified.
        x: torch.Tensor whose first axis indexes samples. A detached copy is used, so
           the caller's tensor is never mutated.
        device: device on which the model copy and ``x`` are placed (default 'cpu').
        MODE: one of
            'samples' (default): return J @ J.T (N x N),
            'params':            return J.T @ J (M x M),
            'minima':            return whichever of the two is smaller.

    OUTPUTS:
        NTK: torch.Tensor

    RAISES:
        ValueError if MODE is not one of the values above.

    EXAMPLE USAGE:
        model = MODEL()  # a torch.nn.Module object
        x_test = torch.ones((100, 1, 28, 28))
        NTK = calculate_NTK(model, x_test)
    """
    if MODE not in ('minima', 'samples', 'params'):
        raise ValueError("MODE must be one of 'minima','samples','params'")

    # detach() rather than `x.requires_grad = False`: the latter mutates the caller's
    # tensor whenever .to() is a no-op, and raises for non-leaf tensors.
    x = x.to(device).detach()
    N = x.shape[0]

    # Work on a copy: extract_weights strips the Parameters out of the module it gets.
    # Unfortunately this doubles memory, which excludes very large models.
    model_clone = copy.deepcopy(model).to(device)
    params, names = extract_weights(model_clone)
    # Every extracted parameter (trainable or frozen) becomes Jacobian columns.
    M = sum(p.numel() for p in params)

    def model_ntk(*args, model=model_clone, names=names):
        load_weights(model, names, tuple(args))
        return model(x)

    Js = torch.autograd.functional.jacobian(model_ntk, tuple(params), create_graph=False, vectorize=True)
    # One row per sample; output and parameter axes are flattened into columns.
    J = torch.cat([jac.reshape(N, -1) for jac in Js], dim=1)

    if MODE == 'samples' or (MODE == 'minima' and N < M):
        return torch.matmul(J, J.T)
    return torch.matmul(J.T, J)

In [99]:
@njit
def zero_pad(A,pad):
    """
    Zero-pad the last two axes of the 4-D array ``A`` by ``pad`` on each side.

    Returns a new array of shape (N, F, H + 2*pad, W + 2*pad). The result is always
    float32 (see the np.zeros call), whatever A's dtype is.
    NOTE(review): assumes pad >= 0 -- a negative pad is not handled here.
    """
    N, F, H, W = A.shape
    P = np.zeros((N, F, H+2*pad, W+2*pad),dtype=np.float32)
    P[:,:,pad:H+pad,pad:W+pad] = A
    return P

# Demo: a 5x5 block of ones padded by one zero on every side.
zero_pad(np.ones((1,1,5,5)),pad=1)

array([[[[0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 1., 0.],
         [0., 1., 1., 1., 1., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0.]]]], dtype=float32)

In [100]:
@njit
def calc_dw(x,w,b,pad,stride,H_,W_):
    """
    Calculates the derivative of conv(x,w) with respect to w.

    Arguments:
        x: input batch, unpacked as [N, C, H, W].
        w: kernel, unpacked as [F, C, HH, WW]; only its shape is read.
        b: unused.
        pad, stride: padding and stride of the convolution.
        H_, W_: height and width of the convolution output.

    output is a float32 array of shape
        [N, F, F, C, HH, WW, H_, W_]
        = [datapoints, out_filters, out_filters, in_channels, kernel_height,
           kernel_width, output_height, output_width]
    Only entries whose two filter indices are equal are written:
        dw[n, f, f, c, i, j, k, l] = xp[n, c, i + stride*k, j + stride*l]
    where xp is x zero-padded by `pad`. Every other entry stays 0.
    """
    dx, dw, db = None, None, None  # dx and db are never used below
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    
    dw = np.zeros((N,F,F,C,HH,WW,H_,W_),dtype=np.float32)

    xp = zero_pad(x,pad)
    #high priority, how to vectorize this operation?
    # Each (n, f, c, i, j, k, l) index is visited exactly once, so `+=` on the
    # zero-initialised array acts as a plain assignment.
    for n in range(N):
        for f in range(F):
            for i in range(HH): 
                for j in range(WW): 
                    for k in range(H_): 
                        for l in range(W_): 
                            for c in range(C): 
                                dw[n,f,f,c,i,j,k,l] += xp[n, c, i+stride*k, j+stride*l]                             
    return dw

@njit
def calc_dx(x,w,b,pad,stride,H_,W_):
    '''
    calculates the derivative of conv(x,w) with respect to x

    Arguments:
        x: input batch, unpacked as [N, C, H, W]; only its shape is read.
        w: kernel, unpacked as [F, C, HH, WW].
        b: unused.
        pad, stride: padding and stride of the convolution.
        H_, W_: height and width of the convolution output.

    output is a float32 nd-array of shape [C, H, W, F, H_, W_]
    (ch_in x og_h x og_w x ch_out x h_out x w_out) with
        dx[c, i, j, f, k, l] = w[f, c, i - stride*k + pad, j - stride*l + pad]
    when both kernel indices fall inside the kernel, and 0 otherwise.
    There is no batch axis because the result does not depend on the values of x.
    '''
    dx, dw, db = None, None, None  # dw and db are never used below
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape 

    dx = np.zeros((C,H,W,F,H_,W_,),dtype=np.float32)
    #high priority, how to vectorize this operation? maybe with np.chunk,split?
    for f in range(F): 
        for i in range(H): 
            for j in range(W):
                for k in range(H_): 
                    for l in range(W_):
                        for c in range(C): 
                            if i-stride*k+pad > HH-1 or j-stride*l+pad > WW-1:
                                continue #this is alternative to padding w with zeros.
                            if i-stride*k+pad < 0 or j-stride*l+pad < 0:
                                continue #this is alternative to padding w with zeros.
                            dx[c,i,j,f,k,l] += w[f, c, i-stride*k+pad, j-stride*l+pad]
    return dx 


In [101]:
#%%timeit

#can I speed this up with chunks?
#l=1
#5.9ms vs 20.8 micro seconds; numba is taking care of the for loop vectorization for us.
#for numpy this seems fine. 
#calc_dx(Xs[l].T,w=Ks[l],b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])

22.4 µs ± 9.62 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [102]:
def relu(X,normalize=False):
    """
    Apply ReLU to the torch tensor ``X``.

    With ``normalize=True`` the result is also shifted and scaled by the constants
    1/sqrt(2*pi) and sqrt(2*pi/(pi-1)). These are the mean and inverse standard
    deviation of relu(z) for a standard-normal z.
    """
    rectified = F.relu(X)
    if not normalize:
        return rectified
    return np.sqrt(2*np.pi/(np.pi-1))*(rectified-1/np.sqrt(2*np.pi))
    

# #Identity
# def activation(x):
#     return x

# @njit
# def d_activation(x):
#     return np.ones(np.shape(x),dtype=np.float32) 


#Tanh
def activation(x):
    """
    Nonlinearity used by the test networks: elementwise tanh of a torch tensor.
    (An identity alternative is kept commented out above.)
    """
    return torch.tanh(x)

@njit
def d_activation(x):
    """
    Derivative of tanh, sech(x)**2 written as cosh(x)**-2, for NumPy arrays.

    Numba-compiled counterpart of the torch-based ``activation`` (tanh).
    """
    return np.cosh(x)**-2

In [103]:
# Seed for torch, random and numpy, reused each time a test model is (re)built.
SEED = 0

In [104]:
def NTK_weights(m):
    """
    Re-initialise a Linear or Conv2d layer with standard-normal weights and bias.

    Intended for ``model.apply(NTK_weights)``; every other module type is left
    untouched. The weight shape of each re-initialised layer is printed. No
    1/sqrt(fan_in) scaling is applied here.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    print(m.weight.shape)
    nn.init.normal_(m.weight.data)
    # Layers built with bias=False store None here; test identity, not Tensor `!=`.
    if m.bias is not None:
        nn.init.normal_(m.bias.data)

In [105]:
# Easy NTK expects one output alone
class dumb_small(torch.nn.Module):
    '''
    Small conv + dense network used as a test case for the NTK code.

    Architecture (``activation`` after every layer except the last):
        d1: Conv2d(1 -> 5, kernel 3)
        d2: Conv2d(5 -> 4, kernel 3, stride 2, padding 1)
        d3: Conv2d(4 -> 3, kernel 3)
        reshape to 2*2*3 = 12 features
        d4: Linear(12 -> 3)
        d5: Linear(3 -> 1)
    The reshape requires d3 to output 3 x 2 x 2 per sample. That holds for 9x9 and
    10x10 single-channel inputs. The single output suits code that expects one
    output per sample.

    Original author's note: it seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super().__init__()
        
        self.d1 = torch.nn.Conv2d(1,5,3,bias=True)
        
        self.d2 = torch.nn.Conv2d(5,4,3,stride=2,padding=1,bias=True)

        self.d3 = torch.nn.Conv2d(4,3,3,stride=1,padding=0,bias=True)
        
        self.d4 = torch.nn.Linear(2*2*3,3,bias=True)
        
        self.d5 = torch.nn.Linear(3,1,bias=True)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        # reshape(-1, 12) alone would silently regroup samples whenever the spatial
        # size is wrong but the element count still divides by 12.
        if x_3.shape[-3:] != (3, 2, 2):
            raise ValueError(
                f"dumb_small expects d3 to output (3, 2, 2) per sample, got {tuple(x_3.shape)}"
            )
        x_4 = x_3.reshape(-1,2*2*3)
        x_5 = activation(self.d4(x_4))
        x_6 = self.d5(x_5)
        return x_6
    
# Easy NTK expects one output alone
class dumb_small_layerwise(torch.nn.Module):
    '''
    Same network as ``dumb_small``, but ``forward`` also returns every intermediate
    activation: (x_6, x_5, x_4, x_3, x_2, x_1, x_0), from the output back to the input.

    The reshape requires d3 to output 3 x 2 x 2 per sample. That holds for 9x9 and
    10x10 single-channel inputs.

    Original author's note: it seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super().__init__()
        
        self.d1 = torch.nn.Conv2d(1,5,3,bias=True)  # H x W: 10 -> 8 (9 -> 7)

        self.d2 = torch.nn.Conv2d(5,4,3,stride=2,padding=1,bias=True)  # 8 -> 4 (7 -> 4)
        
        self.d3 = torch.nn.Conv2d(4,3,3,stride=1,padding=0,bias=True)  # 4 -> 2
        
        self.d4 = torch.nn.Linear(2*2*3,3,bias=True)
        
        self.d5 = torch.nn.Linear(3,1,bias=True)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        # reshape(-1, 12) alone would silently regroup samples whenever the spatial
        # size is wrong but the element count still divides by 12.
        if x_3.shape[-3:] != (3, 2, 2):
            raise ValueError(
                f"dumb_small_layerwise expects d3 to output (3, 2, 2) per sample, got {tuple(x_3.shape)}"
            )
        x_4 = x_3.reshape(-1,2*2*3)
        x_5 = activation(self.d4(x_4))
        x_6 = self.d5(x_5)
        return x_6, x_5, x_4, x_3, x_2, x_1, x_0
    

In [106]:
# Build both test models from the same seeds so that they start with the same
# parameters (d1 and d2 are checked in the next cell). NTK_weights prints each
# re-initialised weight shape, which is the output shown below.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cpu'

model_small = dumb_small()
model_small.to(device)
model_small.apply(NTK_weights)

# Reset every RNG so the layerwise model draws the same parameters.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cpu'

model_layerwise = dumb_small_layerwise()
model_layerwise.to(device)
model_layerwise.apply(NTK_weights)

# 3 samples of 1-channel 9x9 standard-normal inputs.
x_test = np.random.normal(0,1,(3,1,9,9)).astype(np.float32) #n c_in, h, w
x_test = torch.from_numpy(x_test)

torch.Size([5, 1, 3, 3])
torch.Size([4, 5, 3, 3])
torch.Size([3, 4, 3, 3])
torch.Size([3, 12])
torch.Size([1, 3])
torch.Size([5, 1, 3, 3])
torch.Size([4, 5, 3, 3])
torch.Size([3, 4, 3, 3])
torch.Size([3, 12])
torch.Size([1, 3])


In [107]:
# Both models were built from identical seeds, so their first two conv kernels must match.
assert torch.all(model_layerwise.d1.weight == model_small.d1.weight)
assert torch.all(model_layerwise.d2.weight == model_small.d2.weight)

# Easy_NTK

In [108]:
# Reference NTK from calculate_NTK (default MODE='samples': J @ J.T, samples x samples).
NTK_easy = calculate_NTK(model_small,x_test).detach().numpy()

# PyTorch Autograd

In [109]:
# Forward pass whose per-sample outputs are back-propagated one at a time below.
model_small.zero_grad()
y = model_small(x_test)

In [110]:
# This per-sample autograd method agrees between model_layerwise and model_small, meaning
# the calculation is independent of which of the two models is used. The implication is
# that something is wrong with both of my other methods for calculating the same thing,
# since they agree with one another.

# In the future we would iterate over layers instead of hard-coding them like this...
layer_components_w1 = [] 
layer_components_w2 = []
layer_components_w3 = []
layer_components_w4 = []
layer_components_w5 = []

layer_components_b1 = []
layer_components_b2 = []
layer_components_b3 = []
layer_components_b4 = []
layer_components_b5 = []

for output in y:
    # Clear the previous sample's gradients, then back-propagate this sample's output.
    model_small.zero_grad()
    
    output.backward(retain_graph=True)  # keep the graph alive for the next sample

    #Get the tensors
    w1_grad = model_small.d1.weight.grad.detach().numpy()
    w2_grad = model_small.d2.weight.grad.detach().numpy()
    w3_grad = model_small.d3.weight.grad.detach().numpy()
    w4_grad = model_small.d4.weight.grad.detach().numpy()
    w5_grad = model_small.d5.weight.grad.detach().numpy()
    
    b1_grad = model_small.d1.bias.grad.detach().numpy()
    b2_grad = model_small.d2.bias.grad.detach().numpy()
    b3_grad = model_small.d3.bias.grad.detach().numpy()
    b4_grad = model_small.d4.bias.grad.detach().numpy()
    b5_grad = model_small.d5.bias.grad.detach().numpy()

    # Flatten and append. .numpy() shares memory with the .grad tensor and reshape(-1)
    # can return a view, so copy to keep each stored row independent.
    layer_components_w1.append(w1_grad.reshape(-1).copy())
    layer_components_w2.append(w2_grad.reshape(-1).copy())
    layer_components_w3.append(w3_grad.reshape(-1).copy())
    layer_components_w4.append(w4_grad.reshape(-1).copy())
    layer_components_w5.append(w5_grad.reshape(-1).copy())
    
    layer_components_b1.append(b1_grad.reshape(-1).copy())
    layer_components_b2.append(b2_grad.reshape(-1).copy())
    layer_components_b3.append(b3_grad.reshape(-1).copy())
    layer_components_b4.append(b4_grad.reshape(-1).copy())
    layer_components_b5.append(b5_grad.reshape(-1).copy())

In [111]:
# Stack every per-sample list of flattened gradients into an array with one row per sample.
(layer_components_w1, layer_components_w2, layer_components_w3,
 layer_components_w4, layer_components_w5,
 layer_components_b1, layer_components_b2, layer_components_b3,
 layer_components_b4, layer_components_b5) = map(np.array, (
    layer_components_w1, layer_components_w2, layer_components_w3,
    layer_components_w4, layer_components_w5,
    layer_components_b1, layer_components_b2, layer_components_b3,
    layer_components_b4, layer_components_b5,
))

In [112]:
# Autograd NTK: sum over every parameter block G (one row of gradients per sample)
# of the per-block Gram matrix G @ G.T, accumulated in the same order as before.
autograd_NTK = sum(
    grads @ grads.T
    for grads in (
        layer_components_w1,
        layer_components_w2,
        layer_components_w3,
        layer_components_w4,
        layer_components_w5,
        layer_components_b1,
        layer_components_b2,
        layer_components_b3,
        layer_components_b4,
        layer_components_b5,
    )
)

# Layerwise

In [113]:
model_layerwise.to('cpu')
x_6, x_5, x_4, x_3, x_2, x_1, x_0 = model_layerwise(x_test)

# Every list below is indexed by layer position: 0-2 are the conv layers d1-d3,
# 3 is the reshape step, and 4-5 are the dense layers d4-d5. A position where a layer
# has no such quantity holds a float32 [0.0] placeholder.

# Dense weight matrices.
Ws = [np.array([0.0], dtype=np.float32) for _ in range(4)]
Ws += [
    model_layerwise.d4.weight.detach().numpy().astype(np.float32),
    model_layerwise.d5.weight.detach().numpy().astype(np.float32),
]

# Convolution kernels.
Ks = [
    model_layerwise.d1.weight.detach().numpy().astype(np.float32),
    model_layerwise.d2.weight.detach().numpy().astype(np.float32),
    model_layerwise.d3.weight.detach().numpy().astype(np.float32),
]
Ks += [np.array([0.0], dtype=np.float32) for _ in range(3)]

# Layer inputs x_0..x_5 with their axes reversed (.T), which puts the sample axis last.
Xs = [act.detach().numpy().T.astype(np.float32) for act in (x_0, x_1, x_2, x_3, x_4, x_5)]

# Biases; the dense ones become column vectors.
Bs = [
    model_layerwise.d1.bias.detach().numpy().astype(np.float32),
    model_layerwise.d2.bias.detach().numpy().astype(np.float32),
    model_layerwise.d3.bias.detach().numpy().astype(np.float32),
    np.array([0.0], dtype=np.float32),
    model_layerwise.d4.bias.detach().numpy().astype(np.float32)[:, None],
    model_layerwise.d5.bias.detach().numpy().astype(np.float32)[:, None],
]

# Integer sizes, kept as a plain int list so compiled code can consume it.
ds_int = [
    0,
    5*3*3,  # channels_out * kernel_height * kernel_width
    4*3*3,  # channels_out * kernel_height * kernel_width
    3*3*3,  # channels_out * kernel_height * kernel_width
    3,      # number of output features
    1,      # number of output features
]

# NTK-formulation scale factors: index 0 is a placeholder, the rest are 1.0.
ds_array = [np.array([1.0], dtype=np.float32) for _ in range(6)]

# Convolution hyper-parameters; the non-conv positions hold placeholders.
padding = [0, 1, 0, 0, 0, 0]
strides = [1, 2, 1, 0, 0, 0]

# Improve this algorithm's speed: here is the original

In [114]:
@njit
def cross(X):
    """
    Return the Gram matrix X^T X (numba-compiled).

    NOTE(review): callers here appear to pass 2-D (features x samples) arrays, which
    makes the result samples x samples -- confirm for each call site.
    """
    return X.T.dot(X)

In [115]:
def compute_NTK_CNN_w_Bias(Ws, Ks, Xs, Bs, d_int, d_array, strides, padding):
    """
    Reference (NumPy / numba) computation of the NTK of a conv + dense network with
    biases, assembled block by block from per-layer quantities.

    All list arguments are indexed by layer position and hold float32 [0.0]
    placeholders where a layer has no such quantity:
        Ws: dense weight matrices.
        Ks: conv kernels, [channels_out, channels_in, kernel_height, kernel_width].
        Xs: layer inputs with axes reversed (sample axis last); Xs[0] is the network
            input and Xs[L], with L = len(Xs) - 1, feeds the last layer.
        Bs: biases.
        d_int: integer sizes used to build the dense-bias blocks.
        d_array: NTK-formulation factors; the back-propagated signals are divided by
            their square roots.
        strides, padding: convolution hyper-parameters.

    Returns (KNTK, components). `components` lists every contribution in the order
    conjugate kernel, dense weights, dense biases, conv weights, conv biases, and
    KNTK is their sum.

    Relies on Conv2d, calc_dw, calc_dx, cross, d_activation and rearrange from
    elsewhere in this notebook.
    """
    components = []
    
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
   
    n = Xs[0].shape[-1] #number of datapoints

    # Activation derivatives. Index 0 is a placeholder, so entry l+1 belongs to layer l.
    Ds_dense = [np.array([[0.0]],dtype=np.float32)] 
    Ds_conv = [np.array([[0.0]],dtype=np.float32)]
    dws = []
    dxs = []
    ####################################################################################################
    # Dense layers: derivative of the activation at the pre-activation W X + b.
    for l in range(0,L):
        if np.all(Ws[l]!=0):
            Ds_dense.append(d_activation(np.dot(Ws[l],Xs[l]) + Bs[l]))
        else:
            Ds_dense.append(np.array([[0.0]],dtype=np.float32))
    ####################################################################################################
    # Conv layers: the same derivative, flattened to ((f h w), n).
    for l in range(0,L):
        if np.all(Ks[l]!=0):
            #Ds_conv.append((d_activation(Bs[l][None,:,None,None] + Conv2d(Ks[l],strides[l],padding[l]).forward(Xs[l].T)).reshape((n,-1))).T )
            Ds_conv.append(rearrange(d_activation(Bs[l][None,:,None,None] + Conv2d(Ks[l],strides[l],padding[l]).forward(Xs[l].T)),'n f h w -> n (f h w)').T )
        else:
            Ds_conv.append(np.array([[0.0]],dtype=np.float32))
    ####################################################################################################        
    # Per-conv-layer Jacobian blocks, indexed by l directly (no leading placeholder):
    # dws[l] w.r.t. the kernel, dxs[l] w.r.t. the layer input (placeholder for layer 0).
    for l in range(0,L):
        #!!! will need to be updated with strides, padding...
        if np.all(Ks[l]!=0):
            #dw2 = calc_dw(x=Xs[1].T,w=Ks[1],b=0,pad=0,stride=1,H_=Xs[2].shape[1],W_=Xs[2].shape[0])
            dw = calc_dw(x=Xs[l].T,w=Ks[l],b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
            dws.append(rearrange(dw,'n f1 f2 c kh kw dh dw -> n (c f1 kh kw) (f2 dh dw)') )
            #__, f1, f2, c, kh, kw, dH, dW = dw.shape
            #dws.append(dw.reshape((n, c*f1*kh*kw, f2*dH*dW)))
            if l != 0:
                dx = calc_dx(x=Xs[l].T,w=Ks[l],b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
                dxs.append(rearrange(dx,'c ih iw f oh ow -> (c ih iw) (f oh ow)')[None,:,:] )
            else:
                dxs.append(np.array([[0.0]],dtype=np.float32))
        else:
            dws.append(np.array([[0.0]],dtype=np.float32))
            dxs.append(np.array([[0.0]],dtype=np.float32))
    ####################################################################################################
    # The first term is the conjugate kernel X_L^T X_L. It is computed twice, so the
    # in-place `+=` on KNTK below does not modify components[0].
    KNTK = cross(Xs[L])
    components.append(cross(Xs[L]))
    
    ###################################################################################################
    # Dense weight terms of the form (S^T S) * (X^T X).
    for l in range(1,L+1):#l counts layers going forward from 1...
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        
        #Skip over non Dense Layers (their input is more than 2-D). This could be made more rigorous, maybe pass named parameters?
        if len(np.shape(Xs[l-1]))>2:
            continue            
        XtX = cross(Xs[l-1]) #X_3
        S = np.expand_dims(Ws[-1].T.reshape(-1)/np.sqrt(d_array[L]),axis=1) #has shape input to last layer.
        for k in range(L,l-1,-1): #counts backwards from l
            S = Ds_dense[k]*S
            if k > l:
                S = np.dot(S.T,Ws[k-1]).T/np.sqrt(d_array[k-1])
        components.append(cross(S) * XtX)
        KNTK += cross(S) * XtX
    ###################################################################################################
    # Bias terms of the dense layers. For l = L+1 the inner loop is empty (S = 1),
    # which gives the last layer's bias term.
    for l in range(L+1,0,-1):
        if len(np.shape(Xs[l-1]))>2: #skip the convolutional layers
            continue   
        S=1
        for k in range(L,l-1,-1):
            S = np.dot(Ws[k].T/np.sqrt(d_array[k]),S)
            S = S * Ds_dense[k]
        S = np.multiply(np.ones((d_int[l-1],n),dtype=np.float32),S)
        components.append(cross(S))
        KNTK += cross(S)
    ####################################################################################################
    # Conv weight terms.
    for l in range(1,L+1):
        #Skip over non CNN layers. This could be made more rigorous, but for LeNets it should work
        if len(np.shape(Xs[l]))<=2:
            continue
        #Need to count backwards the Dense layers, since the algorithm is different... 
        S = Ws[-1].T / np.sqrt(d_array[L])
        for k in range(L,l-1,-1):
            if len(np.shape(Xs[k-1]))<=2: #"if k is a dense layer"
                S = S*Ds_dense[k]
                # NOTE(review): `np.all(Ws[k-1])==0` compares the result of np.all with 0.
                # It is True when Ws[k-1] has any zero entry, so this branch runs only
                # when every entry is non-zero. Other loops spell the test
                # `np.all(W != 0)`; confirm the intent.
                if k > l and not(np.all(Ws[k-1])==0):
                    S = np.dot(S.T,Ws[k-1]).T/np.sqrt(d_array[k-1])
            if len(np.shape(Xs[k-1]))>2: #and this is "if k is a conv layer"
                S = S * Ds_conv[k-1]
                if k-1 > l:
                    S = (dxs[k-2] @ S) / np.sqrt(d_array[k-1]) #this index probably is either one less, or the list is set up funky.
                if k-1 == l:
                    if len(np.shape(S))<=2: #this takes care of the reshape layer on the last convolution
                        S = S[None,:,:]
                    S = (dws[k-2] @ S) #/np.sqrt(d_array[k-1])
                    break
        S = np.diagonal(S,0,2,0)
        components.append(cross(S))
        KNTK += cross(S)
    ####################################################################################################
    #now bias in the convolutional layers
    for l in range(1,L+1):
        if len(np.shape(Xs[l]))<=2:
            continue
        S = Ws[-1].T / np.sqrt(d_array[L])
        
        for k in range(L,l-1,-1):
            if len(np.shape(Xs[k-1]))<=2: #"if k is a dense layer"
                S = S*Ds_dense[k]
                # NOTE(review): same `np.all(...)==0` precedence quirk as in the conv weight loop.
                if k > l and not(np.all(Ws[k-1])==0):
                    S = np.dot(S.T,Ws[k-1]).T/np.sqrt(d_array[k-1])
            if len(np.shape(Xs[k-1]))>2: #and this is "if k is a conv layer"
                
                S = S * Ds_conv[k-1]
                if k-1 > l:
                    S = (dxs[k-2] @ S) / np.sqrt(d_array[k-1]) #this index probably is either one less, or the list is set up funky.
                if k-1 == l:
                    if len(np.shape(S))<=2:
                        S = S[None,:,:]
                    N = np.shape(Ks[l-1])[0] #number of output channels (filters) of kernel l-1
                    # Split into one chunk per filter and sum each chunk, giving one
                    # bias-gradient row per filter.
                    S = np.split(S,N,axis=1)
                    S = np.array(S)
                    S = np.sum(S,axis=(1,2))
                    break
        components.append(cross(S))
        KNTK += cross(S)
        
    return KNTK, components

### Here is the improved one

In [116]:
model_layerwise.to('cuda')
x_test_gpu = x_test.to('cuda')
x_6, x_5, x_4, x_3, x_2, x_1, x_0 = model_layerwise(x_test_gpu)

# Torch/CUDA counterparts of the per-layer lists. Positions 0-2 are the conv layers
# d1-d3, 3 is the reshape step, and 4-5 are the dense layers d4-d5. float32 [0.0] CUDA
# tensors fill the positions a layer does not use. Model parameters are referenced
# directly, with no detach or copy.

# Dense weight matrices.
Wst = [torch.tensor([0.0], dtype=torch.float32).to('cuda') for _ in range(4)]
Wst += [model_layerwise.d4.weight, model_layerwise.d5.weight]

# Convolution kernels.
Kst = [model_layerwise.d1.weight, model_layerwise.d2.weight, model_layerwise.d3.weight]
Kst += [torch.tensor([0.0], dtype=torch.float32).to('cuda') for _ in range(3)]

# Layer inputs with their axes reversed (sample axis last).
# NOTE(review): recent PyTorch releases warn that `.T` on tensors that are not 2-D is
# deprecated; `t.permute(*reversed(range(t.ndim)))` is the explicit equivalent. Confirm
# this for the installed version.
Xst = [act.T for act in (x_0, x_1, x_2, x_3, x_4, x_5)]

# Biases.
Bst = [
    model_layerwise.d1.bias,
    model_layerwise.d2.bias,
    model_layerwise.d3.bias,
    torch.tensor([0.0], dtype=torch.float32).to('cuda'),
    model_layerwise.d4.bias,
    model_layerwise.d5.bias,
]

# Integer sizes, kept as a plain Python int list.
ds_int = [
    0,
    5*3*3,  # channels_out * kernel_height * kernel_width
    4*3*3,  # channels_out * kernel_height * kernel_width
    3*3*3,  # channels_out * kernel_height * kernel_width
    3,      # number of output features
    1,      # number of output features
]

# NTK-formulation scale factors: index 0 is a placeholder, the rest are 1.0.
ds_arrayt = [torch.tensor([1.0], dtype=torch.float32).to('cuda') for _ in range(6)]

# Convolution hyper-parameters; the non-conv positions hold placeholders.
padding = [0, 1, 0, 0, 0, 0]
strides = [1, 2, 1, 0, 0, 0]

# Layer modules by position; the reshape step has no module, hence the 0.0 placeholder.
layers = [
    model_layerwise.d1,
    model_layerwise.d2,
    model_layerwise.d3,
    0.0,
    model_layerwise.d4,
    model_layerwise.d5,
]

In [127]:
def cross_pt(X,device='cuda'):
    """
    Gram matrix X^T X, computed on ``device`` and always returned on the CPU.
    """
    moved = X.to(device)
    gram = moved.T @ moved
    return gram.cpu()

def d_activationt(x):
    """
    Derivative of tanh for torch tensors: sech(x)**2, computed as cosh(x)**-2.
    """
    cosh_x = torch.cosh(x)
    return cosh_x ** -2

def calc_dwt(x,w,b,pad,stride,H_,W_,device='cpu'):
    """
    Derivative of conv(x, w) with respect to w, as a dense torch tensor.

    Parameters
    ----------
    x : torch.Tensor of shape [N, C, H, W].
    w : torch.Tensor of shape [F, C, HH, WW]; only its shape is read.
    b : unused, kept for signature compatibility.
    pad, stride : zero padding and stride of the convolution (stride >= 1).
    H_, W_ : height and width of the convolution output.
    device : device of the returned tensor (x is moved there if needed).

    Returns
    -------
    dw : float32 tensor of shape [N, F, F, C, HH, WW, H_, W_] with
        dw[n, f, f, c, i, j, k, l] = x_padded[n, c, i + stride*k, j + stride*l]
    and zeros wherever the two filter indices differ.

    Raises
    ------
    ValueError if ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError('stride must be a positive int')
    N, C = x.shape[0], x.shape[1]
    n_filters, _, HH, WW = w.shape
    dw = torch.zeros((N, n_filters, n_filters, C, HH, WW, H_, W_), dtype=torch.float32, device=device)
    xp = torch.nn.functional.pad(x, (pad, pad, pad, pad)).to(device)
    # For kernel tap (i, j), the output grid reads a strided H_ x W_ window of the
    # padded input. Copy that window for all samples and channels in one slice
    # instead of one scalar tensor op per (n, c, k, l).
    for i in range(HH):
        for j in range(WW):
            window = xp[:, :, i:i + stride*(H_ - 1) + 1:stride, j:j + stride*(W_ - 1) + 1:stride]
            for f in range(n_filters):
                dw[:, f, f, :, i, j] = window
    return dw

def calc_dxt(x,w,b,pad,stride,H_,W_,device='cpu'):
    '''
    Derivative of conv(x, w) with respect to x, as a dense torch tensor.

    Parameters
    ----------
    x : torch.Tensor of shape [N, C, H, W]; only its shape is read.
    w : torch.Tensor of shape [F, C, HH, WW] (its channel count must match x's).
    b : unused, kept for signature compatibility.
    pad, stride : zero padding and stride of the convolution (stride >= 1).
    H_, W_ : height and width of the convolution output.
    device : device of the returned tensor (w is moved there if needed).

    Returns
    -------
    dx : float32 tensor of shape [C, H, W, F, H_, W_]
        (ch_in x og_h x og_w x ch_out x h_out x w_out) with
        dx[c, i, j, f, k, l] = w[f, c, i - stride*k + pad, j - stride*l + pad]
    when both kernel indices lie inside the kernel, and 0 otherwise. There is no
    batch axis because the result does not depend on the values of x.

    Raises
    ------
    ValueError if ``stride`` is smaller than 1.
    '''
    if stride < 1:
        raise ValueError('stride must be a positive int')
    _, C, H, W = x.shape
    n_filters, _, HH, WW = w.shape
    dx = torch.zeros((C, H, W, n_filters, H_, W_), dtype=torch.float32, device=device)
    # [F, C, HH, WW] -> [C, HH, WW, F], so a kernel slice lines up with dx[:, rows, cols, :].
    w_t = w.permute(1, 2, 3, 0).to(device)
    for k in range(H_):
        top = stride*k - pad  # input row that kernel row 0 covers for output row k
        i0, i1 = max(top, 0), min(top + HH, H)
        if i0 >= i1:
            continue
        for l in range(W_):
            left = stride*l - pad  # input column that kernel column 0 covers
            j0, j1 = max(left, 0), min(left + WW, W)
            if j0 >= j1:
                continue
            dx[:, i0:i1, j0:j1, :, k, l] = w_t[:, i0 - top:i1 - top, j0 - left:j1 - left, :]
    return dx

In [124]:
def compute_NTK_CNN_w_Bias_pt(Ws=Wst, Ks=Kst, Xs=Xst, Bs=Bst, d_int=ds_int, d_array=ds_arrayt, strides=strides, padding=padding, dev='cuda'):
    """
    Assemble the NTK of a mixed conv/dense network one parameter group at a time (PyTorch).

    Every parameter group contributes an n x n term, where n is the number of
    datapoints, read from ``Xs[0].shape[-1]``. Each term is appended to
    ``components`` and also accumulated into ``KNTK``. Terms are added in
    this order:
      - last-layer weights;
      - dense weight terms;
      - dense bias terms (from the output layer downwards);
      - conv weight terms;
      - conv bias terms.
    Layers skipped by a loop contribute nothing to it.

    NOTE: the defaults (Wst, Kst, Xst, Bst, ds_int, ds_arrayt, strides,
    padding) are globals captured when this ``def`` was executed. The body
    also reads the global ``layers``, a per-layer list of callables.

    Parameters
    ----------
    Ws : list of tensors
        Dense weight matrices. Layer l is treated as dense only when every
        entry of Ws[l] is nonzero (``torch.all(Ws[l]!=0)``).
        NOTE(review): presumably non-dense layers hold a zero placeholder -- confirm.
    Ks : list of tensors
        Conv kernels, unpacked by calc_dwt/calc_dxt as (F, C, HH, WW). The same
        nonzero test decides whether layer l is a conv layer.
    Xs : list of tensors
        Layer inputs, from X_0 (network input) to X_L (last hidden layer),
        stored transposed. ``Xs[l].T`` is what ``layers[l]`` consumes. For conv
        layers, ``Xs[l].T`` is passed to calc_dwt/calc_dxt as (N, C, H, W).
    Bs : unused in this function.
    d_int : sequence of int
        Layer widths, used to broadcast the dense-bias derivative to (d_int[l-1], n).
    d_array : indexable of tensors
        Widths for the 1/sqrt(width) scaling. Entries must be tensors, since
        ``torch.sqrt(d_array[k])`` is applied.
    strides, padding : sequences of int
        Per-layer conv stride/padding forwarded to calc_dwt / calc_dxt.
    dev : str or torch.device
        Device for intermediate tensors. cross_pt is always called with its
        own default device and returns CPU tensors.

    Returns
    -------
    KNTK : torch.Tensor
        Sum of all terms. It is on the CPU, because cross_pt returns CPU tensors.
    components : list of torch.Tensor
        The individual terms, in the order they were added.
    """
    components = []
    
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
   
    n = Xs[0].shape[-1] #number of datapoints

    # Activation derivatives per layer. Index 0 is a (1,1) zero spacer, so the entry
    # computed for layers[l] lands at index l+1.
    Ds_dense = [torch.zeros((1,1),dtype=torch.float32).to(dev)] 
    Ds_conv =  [torch.zeros((1,1),dtype=torch.float32).to(dev)]
    dws = []
    dxs = []
    ####################################################################################################
    # Dense layers: d_activationt applied to layers[l](Xs[l].T), transposed back.
    # Layers whose Ws[l] contains a zero get a [[0.0]] placeholder instead.
    for l in range(0,L):
        if torch.all(Ws[l]!=0):
            Ds_dense.append(d_activationt(layers[l](Xs[l].T)).T)
        else:
            Ds_dense.append(torch.tensor([[0.0]],dtype=torch.float32).to(dev))
    ####################################################################################################
    # Conv layers: same, but layers[l]'s output is flattened to (n, -1) first and the
    # result transposed to (-1, n).
    for l in range(0,L):
        if torch.all(Ks[l]!=0): 
            #Ds_conv.append(rearrange(d_activationt(layers[l](Xs[l].T)),'n f h w -> n (f h w)').T)
            Ds_conv.append(d_activationt(layers[l](Xs[l].T).reshape((n,-1))).T)
        else:
            Ds_conv.append(torch.tensor([[0.0]],dtype=torch.float32).to(dev))
    ####################################################################################################        
    # Conv layers: kernel derivative (calc_dwt) and input derivative (calc_dxt), both
    # flattened to matrices with einops. No spacer here, so dws[l]/dxs[l] belong to layers[l].
    # dxs[0] is only a placeholder: it is never read below (dxs is indexed with k-2 where
    # k-1 > l >= 1).
    for l in range(0,L):
        #!!! will need to be updated with strides, padding...
        if torch.all(Ks[l]!=0):
            #dw2 = calc_dw(x=Xs[1].T,w=Ks[1],b=0,pad=0,stride=1,H_=Xs[2].shape[1],W_=Xs[2].shape[0])
            # Output height/width are read from the next layer's activations.
            # NOTE(review): presumably Xs[l+1] is stored as (W, H, C, N), mirroring Xs[l].T = (N, C, H, W).
            dw = calc_dwt(x=Xs[l].T,w=Ks[l],b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0],device=dev)
            dws.append(rearrange(dw,'n f1 f2 c kh kw dh dw -> n (c f1 kh kw) (f2 dh dw)') )
            #__, f1, f2, c, kh, kw, dh, dw = dw.shape
            #dws.append(rearrange(dw,'n f1 f2 c kh kw dh dw -> n (c f1 kh kw) (f2 dh dw)') )
            #dws.append(dw.reshape((n, c*f1*kh*kw, f2*dh*dw)))
            if l != 0:
                dx = calc_dxt(x=Xs[l].T,w=Ks[l],b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0],device=dev)
                # leading None adds a broadcast batch axis for the later `dxs[...] @ S` products
                dxs.append(rearrange(dx,'c ih iw f oh ow -> (c ih iw) (f oh ow)')[None,:,:] )
            else:
                dxs.append(torch.tensor([[0.0]],dtype=torch.float32))
        else:
            dws.append(torch.tensor([[0.0]],dtype=torch.float32).to(dev))
            dxs.append(torch.tensor([[0.0]],dtype=torch.float32).to(dev))
    ####################################################################################################
    #The first term is just conjugate kernel
    # (last-layer weights: Gram matrix of X_L). The two separate cross_pt calls keep
    # components[0] from aliasing KNTK, which is updated in place by `+=` below.
    KNTK = cross_pt(Xs[L])
    components.append(cross_pt(Xs[L]))
    ###################################################################################################
    for l in range(1,L+1):#l counts layers going forward from 1...
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        # (* is elementwise: back-propagated Gram times input Gram)
        
        #Skip over non Dense Layers, This could be made more rigorous, maybe pass named parameters?
        if len(np.shape(Xs[l-1]))>2:
            continue            
        XtX = cross_pt(Xs[l-1]) #X_3
        S = Ws[-1].T.reshape(-1)[:,None]/torch.sqrt(d_array[L]) #has shape input to last layer.
        # Back-propagate S from the output down to layer l. Each step scales by the
        # activation derivative; while above l, it also multiplies through Ws[k-1]
        # with 1/sqrt(width).
        for k in range(L,l-1,-1): #counts backwards from L down to l
            S = Ds_dense[k]*S
            if k > l:
                S = torch.matmul(S.T,Ws[k-1]).T/torch.sqrt(d_array[k-1])
        components.append(cross_pt(S) * XtX)
        KNTK += cross_pt(S) * XtX
    ###################################################################################################
    #Now Bias in Dense Layers
    # l = L+1 is the output-layer bias: its inner loop is empty, so S stays [1.0].
    for l in range(L+1,0,-1):
        if len(np.shape(Xs[l-1]))>2: #skip the convolutional layers
            continue   
        S=torch.tensor([1.0],dtype=torch.float32).to(dev)
        for k in range(L,l-1,-1):
            S = torch.matmul(Ws[k].T/torch.sqrt(d_array[k]),S)
            S = (S * Ds_dense[k].T).T
        # broadcast the back-propagated derivative to a (d_int[l-1], n) matrix before the Gram product
        S = torch.multiply(torch.ones((d_int[l-1],n),dtype=torch.float32).to(dev),S)
        components.append(cross_pt(S))
        KNTK += cross_pt(S)
    ####################################################################################################
    # Conv weight terms. First propagate S down through the dense layers (Ds_dense / Ws),
    # then through the conv layers (Ds_conv / dxs). Finish with the kernel derivative
    # dws at conv layer l. Ds_conv[k-1], dxs[k-2] and dws[k-2] all belong to layers[k-2],
    # because only Ds_conv carries a spacer at index 0.
    for l in range(1,L+1):
        #Skip over non CNN layers. This could be made more rigorous, but for LeNets it should work
        if len((Xs[l]).shape)<=2:
            continue
        #Need to count backwards the Dense layers, since the algorithm is different... 
        S = Ws[-1].T / torch.sqrt(d_array[L])
        for k in range(L,l-1,-1):
            if len(np.shape(Xs[k-1]))<=2: #"if k is a dense layer"
                S = S*Ds_dense[k]
                if k > l and not(torch.all(Ws[k-1]==0)):
                    S = torch.matmul(S.T,Ws[k-1]).T/torch.sqrt(d_array[k-1])
            if len(np.shape(Xs[k-1]))>2: #and this is "if k is a conv layer"
                S = S * Ds_conv[k-1]
                if k-1 > l:
                    S = (dxs[k-2] @ S) / torch.sqrt(d_array[k-1]) #dxs[k-2] and Ds_conv[k-1] both belong to layers[k-2]; NOTE(review): original author was unsure about this index
                if k-1 == l:
                    if len(np.shape(S))<=2: #this takes care of the reshape layer on the last convolution
                        S = S[None,:,:]
                    S = (dws[k-2] @ S) #/np.sqrt(d_array[k-1])
                    break
        # dws has a leading per-sample axis (see its rearrange above). This keeps only the
        # entries where that axis (dim 0) matches dim 2 of the product, presumably S's sample axis.
        S = torch.diagonal(S,0,0,2) #tensor, offset, axis1, axis2
        components.append(cross_pt(S))
        KNTK += cross_pt(S)
    ####################################################################################################
    #now bias in the convolutional layers
    # Same back-propagation as the conv weight terms. At conv layer l, instead of
    # applying dws, the result is split along dim 1 and each chunk is summed.
    for l in range(1,L+1):
        if len(np.shape(Xs[l]))<=2:
            continue
        S = Ws[-1].T / torch.sqrt(d_array[L])
        
        for k in range(L,l-1,-1):
            if len(np.shape(Xs[k-1]))<=2: #"if k is a dense layer"
                S = S*Ds_dense[k]
                if k > l and not(torch.all(Ws[k-1]==0)):
                    S = torch.matmul(S.T,Ws[k-1]).T/torch.sqrt(d_array[k-1])
            if len(np.shape(Xs[k-1]))>2: #and this is "if k is a conv layer"
                
                S = S * Ds_conv[k-1]
                if k-1 > l:
                    S = (dxs[k-2] @ S) / torch.sqrt(d_array[k-1]) #dxs[k-2] and Ds_conv[k-1] both belong to layers[k-2]; NOTE(review): original author was unsure about this index
                if k-1 == l:
                    if len(np.shape(S))<=2:
                        S = S[None,:,:]
                    N = Ks[l-1].shape[0] #number of filters F of Ks[l-1] (layout (F, C, HH, WW))
                    
                    # split dim 1 into N chunks (presumably one per filter) and sum each chunk over dims 1 and 2
                    S = torch.chunk(S,N,dim=1)
                    S = torch.stack(S)
                    S = torch.sum(S,dim=[1,2])
                    break
        components.append(cross_pt(S))
        KNTK += cross_pt(S)
        
    return KNTK, components

In [125]:
# Line-profile the PyTorch implementation (line_profiler extension). The run captured
# below failed because calc_dwt was numba-compiled at the time, and numba cannot type torch tensors.
%lprun -f compute_NTK_CNN_w_Bias_pt compute_NTK_CNN_w_Bias_pt(Wst, Kst, Xst, Bst, ds_int, ds_arrayt, strides, padding)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1mnon-precise type pyobject[0m
[0m[1mDuring: typing of argument at <ipython-input-123-f0295ac2d280> (12)[0m
[1m
File "<ipython-input-123-f0295ac2d280>", line 12:[0m
[1mdef calc_dwt(x,w,b,pad,stride,H_,W_,device='cpu'):
    <source elided>
    """
[1m    dx, dw, db = None, None, None
[0m    [1m^[0m[0m

This error may have been caused by the following argument(s):
- argument 0: [1mCannot determine Numba type of <class 'torch.Tensor'>[0m
- argument 1: [1mCannot determine Numba type of <class 'torch.nn.parameter.Parameter'>[0m


In [51]:
# Line-profile the NumPy implementation (compute_NTK_CNN_w_Bias, defined elsewhere in the notebook).
%lprun -f compute_NTK_CNN_w_Bias compute_NTK_CNN_w_Bias(Ws, Ks, Xs, Bs, ds_int, ds_array, strides, padding)

In [47]:
# Layerwise NTK from the NumPy implementation; `components` presumably holds the
# per-parameter-group terms, as in the PyTorch version.
NTK_layerwise, components = compute_NTK_CNN_w_Bias(Ws, Ks, Xs, Bs, ds_int, ds_array, strides, padding)

# Compare the NTK implementations against one another

In [48]:
# Display the layerwise NTK.
NTK_layerwise

array([[  4.314115 ,   1.1284418,  -0.3944409],
       [  1.1284418, 215.32312  ,   9.727457 ],
       [ -0.3944409,   9.727457 ,  25.41427  ]], dtype=float32)

In [49]:
# Display the NTK_easy reference for comparison.
NTK_easy

array([[  4.314119  ,   1.1284422 ,  -0.39444053],
       [  1.1284423 , 215.32306   ,   9.72744   ],
       [ -0.39444053,   9.727441  ,  25.41423   ]], dtype=float32)

In [50]:
# Display the autograd-based NTK for comparison.
autograd_NTK

array([[  4.314119  ,   1.1284423 ,  -0.39444053],
       [  1.1284423 , 215.32307   ,   9.72744   ],
       [ -0.39444053,   9.72744   ,  25.41423   ]], dtype=float32)

In [63]:
# Pairwise check that the three NTK computations agree within np.allclose's default tolerances.
print(np.allclose(NTK_layerwise.cpu().detach().numpy(),NTK_easy))
print(np.allclose(NTK_layerwise.detach().numpy(),autograd_NTK))
print(np.allclose(NTK_easy,autograd_NTK))

True
True
True


# Compare Layerwise Components

In [464]:
# Reference per-group NTK terms: each layer_components_* matrix (defined elsewhere) gives
# its group's contribution as M @ M.T. Weights are printed first, then biases.
print(layer_components_w1 @ layer_components_w1.T)
print(layer_components_w2 @ layer_components_w2.T)
print(layer_components_w3 @ layer_components_w3.T)
print(layer_components_w4 @ layer_components_w4.T)
print(layer_components_w5 @ layer_components_w5.T)
print(' ')
print(layer_components_b1 @ layer_components_b1.T)
print(layer_components_b2 @ layer_components_b2.T)
print(layer_components_b3 @ layer_components_b3.T)
print(layer_components_b4 @ layer_components_b4.T)
print(layer_components_b5 @ layer_components_b5.T)

[[ 0.16559178  0.12416241  0.32319853]
 [ 0.12416241 93.88825     1.362006  ]
 [ 0.32319853  1.362006    7.835956  ]]
[[ 1.3036881e-01  4.5428537e-02 -7.8541219e-02]
 [ 4.5428537e-02  4.9995529e+01  1.2862343e-01]
 [-7.8541219e-02  1.2862343e-01  3.1269197e+00]]
[[ 0.02996182  0.03588624 -0.02337419]
 [ 0.03588624 12.629321    0.02550253]
 [-0.02337419  0.02550253  2.1891887 ]]
[[ 2.5031105e-02 -7.2397418e-02 -1.5084391e-02]
 [-7.2397418e-02  3.5128334e+01  2.5679963e+00]
 [-1.5084391e-02  2.5679963e+00  7.1736832e+00]]
[[ 2.9433024  -0.14043024 -1.6482291 ]
 [-0.14043024  0.5340021   0.72886175]
 [-1.6482291   0.72886175  1.9651612 ]]
 
[[8.1973737e-03 9.6674636e-02 6.0739327e-02]
 [9.6674636e-02 1.6477449e+01 3.4951904e+00]
 [6.0739327e-02 3.4951904e+00 1.2404134e+00]]
[[ 0.00824694  0.00372328 -0.01461124]
 [ 0.00372328  1.6762439  -0.15379384]
 [-0.01461124 -0.15379384  0.1838726 ]]
[[ 1.0811921e-03  5.9116962e-03 -3.4120469e-03]
 [ 5.9116962e-03  4.2559952e-01  8.1217535e-05]
 [-3

In [465]:
# Inspect some of the computed terms. Per the checks below: [2]=b5, [3]=b4, [4]=w1, [1]=w4, [0]=w5.
print(components[2])
print(components[3])
print(components[4])
print(components[1])
print(components[0])

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[2.3376e-03, 2.9483e-02, 4.8739e-03],
        [2.9483e-02, 3.5683e+00, 5.7297e-01],
        [4.8739e-03, 5.7297e-01, 6.2731e-01]], grad_fn=<MmBackward>)
tensor([[ 0.1656,  0.1242,  0.3232],
        [ 0.1242, 93.8883,  1.3620],
        [ 0.3232,  1.3620,  7.8360]], grad_fn=<MmBackward>)
tensor([[ 2.5031e-02, -7.2398e-02, -1.5085e-02],
        [-7.2398e-02,  3.5128e+01,  2.5680e+00],
        [-1.5085e-02,  2.5680e+00,  7.1737e+00]], grad_fn=<MulBackward0>)
tensor([[ 2.9433, -0.1404, -1.6482],
        [-0.1404,  0.5340,  0.7289],
        [-1.6482,  0.7289,  1.9652]], grad_fn=<MmBackward>)


In [466]:
# components[0..3] vs. the reference terms for w5, w4, b5 and b4 (default np.allclose tolerances).
print(np.allclose(components[0].detach().numpy(), layer_components_w5 @ layer_components_w5.T))
print(np.allclose(components[1].detach().numpy(), layer_components_w4 @ layer_components_w4.T))
print(np.allclose(components[2].detach().numpy(), layer_components_b5 @ layer_components_b5.T))
print(np.allclose(components[3].detach().numpy(), layer_components_b4 @ layer_components_b4.T))

True
True
True
True


In [467]:
# components[4..6] vs. the reference weight terms for w1, w2, w3 (rtol loosened to 1e-4).
print(np.allclose(components[4].detach().numpy(), layer_components_w1 @ layer_components_w1.T,rtol=1e-4))
print(np.allclose(components[5].detach().numpy(), layer_components_w2 @ layer_components_w2.T,rtol=1e-4))
print(np.allclose(components[6].detach().numpy(), layer_components_w3 @ layer_components_w3.T,rtol=1e-4))

True
True
True


In [468]:
# components[7..9] vs. the reference bias terms for b1, b2, b3 (rtol loosened to 1e-3 / 1e-2).
print(np.allclose(components[7].detach().numpy(), layer_components_b1 @ layer_components_b1.T,rtol=1e-3))
print(np.allclose(components[8].detach().numpy(), layer_components_b2 @ layer_components_b2.T,rtol=1e-3))
print(np.allclose(components[9].detach().numpy(), layer_components_b3 @ layer_components_b3.T,rtol=1e-2))

True
True
True
