In [1]:
#from layerwise_ntk import compute_NTK_CNN
import numpy as np
import random
#import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

from numba import njit

In [3]:
SEED = 1
how_many = 3
width = 2

In [4]:
def activation(x):
    return torch.tanh(x)

def d_activationt(x):
    return torch.cosh(x)**-2

# def activation(x):
#     return torch.relu(x)

# def d_activationt(x):
#     x[x>0.0]=1.0
#     x[x<0.0]=0.0
#     return x

In [5]:
def NTK_weights(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]
    if isinstance(m, nn.Conv2d):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]

In [6]:
class dumb_small(torch.nn.Module):
    '''
    simple network for test cases
    
    
    It seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super(dumb_small, self).__init__()
        
        self.d1 = torch.nn.Conv2d(1,width,3,stride=1,padding=1,bias=True) #28 -> 28

        self.d2 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d3 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d4 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d5 = torch.nn.Linear(width*28*28,1,bias=True)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        x_5 = x_4.reshape(how_many,-1)
        x_6 = activation(self.d5(x_5))
        return x_6 

class dumb_small_layerwise(torch.nn.Module):
    '''
    simple network for test cases
    
    
    It seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super(dumb_small_layerwise, self).__init__()
        
        self.d1 = torch.nn.Conv2d(1,width,3,stride=1,padding=1,bias=True) #28 -> 28

        self.d2 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d3 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d4 = torch.nn.Conv2d(width,width,3,stride=1,padding=1,bias=True) #28 -> 28
        
        self.d5 = torch.nn.Linear(width*28*28,1,bias=True)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        x_5 = x_4.reshape(how_many,-1)
        x_6 = activation(self.d5(x_5))
        return x_6, x_5, x_4, x_3, x_2, x_1, x_0

In [21]:
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cuda'

model = dumb_small_layerwise()
model.apply(NTK_weights)
model.to(device)

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

model_2 = dumb_small()
model_2.apply(NTK_weights)

x_test = np.random.normal(0,1,(how_many,1,28,28)).astype(np.float32) #n c_in, h, w
x_test = torch.from_numpy(x_test)


torch.Size([2, 1, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([1, 1568])
torch.Size([2, 1, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([2, 2, 3, 3])
torch.Size([1, 1568])


# Autograd Method

In [28]:
model_2.zero_grad()
y = model_2(x_test)

In [32]:
#in the future we would iterate over layers instead of like this...
layer_components_w1 = [] 
layer_components_w2 = []
layer_components_w3 = []
layer_components_w4 = []
layer_components_w5 = []

layer_components_b1 = []
layer_components_b2 = []
layer_components_b3 = []
layer_components_b4 = []
layer_components_b5 = []

for i in range(len(y)):
    model_2.zero_grad()
    y = model_2(x_test)
    y[i].backward(retain_graph=True)
    print(y[i])
    #Get the tensors
    w1_grad = model_2.d1.weight.grad.detach().numpy()
    w2_grad = model_2.d2.weight.grad.detach().numpy()
    w3_grad = model_2.d3.weight.grad.detach().numpy()
    w4_grad = model_2.d4.weight.grad.detach().numpy()
    w5_grad = model_2.d5.weight.grad.detach().numpy()
    
    b1_grad = model_2.d1.bias.grad.detach().numpy()
    b2_grad = model_2.d2.bias.grad.detach().numpy()
    b3_grad = model_2.d3.bias.grad.detach().numpy()
    b4_grad = model_2.d4.bias.grad.detach().numpy()
    b5_grad = model_2.d5.bias.grad.detach().numpy()

    #reshape and append. deep copy neccessary or else they are the same objects
    layer_components_w1.append(w1_grad.reshape(-1).copy())
    layer_components_w2.append(w2_grad.reshape(-1).copy())
    layer_components_w3.append(w3_grad.reshape(-1).copy())
    layer_components_w4.append(w4_grad.reshape(-1).copy())
    layer_components_w5.append(w5_grad.reshape(-1).copy())
    
    layer_components_b1.append(b1_grad.reshape(-1).copy())
    layer_components_b2.append(b2_grad.reshape(-1).copy())
    layer_components_b3.append(b3_grad.reshape(-1).copy())
    layer_components_b4.append(b4_grad.reshape(-1).copy())
    layer_components_b5.append(b5_grad.reshape(-1).copy())

tensor([1.0000], grad_fn=<SelectBackward>)
tensor([-1.], grad_fn=<SelectBackward>)
tensor([-1.], grad_fn=<SelectBackward>)


In [33]:
layer_components_w1 = np.array(layer_components_w1)
layer_components_w2 = np.array(layer_components_w2)
layer_components_w3 = np.array(layer_components_w3)
layer_components_w4 = np.array(layer_components_w4)
layer_components_w5 = np.array(layer_components_w5)

layer_components_b1 = np.array(layer_components_b1)
layer_components_b2 = np.array(layer_components_b2)
layer_components_b3 = np.array(layer_components_b3)
layer_components_b4 = np.array(layer_components_b4)
layer_components_b5 = np.array(layer_components_b5)

In [47]:
layer_components_b5.shape

(3, 1)

In [34]:
autograd_NTK = layer_components_w1 @ layer_components_w1.T+\
    layer_components_w2 @ layer_components_w2.T+\
    layer_components_w3 @ layer_components_w3.T+\
    layer_components_w4 @ layer_components_w4.T+\
    layer_components_w5 @ layer_components_w5.T+\
    layer_components_b1 @ layer_components_b1.T+\
    layer_components_b2 @ layer_components_b2.T+\
    layer_components_b3 @ layer_components_b3.T+\
    layer_components_b4 @ layer_components_b4.T+\
    layer_components_b5 @ layer_components_b5.T

In [35]:
autograd_NTK

array([[1.9132348e-07, 0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], dtype=float32)

# Now Layerwise

In [13]:
x_test = x_test.to('cuda')
x_6, x_5, x_4, x_3, x_2, x_1, x_0 = model(x_test)

#These need to be numpy
Ws = []
Ws.append(torch.tensor([0.0],dtype=torch.float32)) 
Ws.append(torch.tensor([0.0],dtype=torch.float32)) 
Ws.append(torch.tensor([0.0],dtype=torch.float32))
Ws.append(torch.tensor([0.0],dtype=torch.float32)) #spacer
Ws.append(torch.tensor([0.0],dtype=torch.float32))
Ws.append(model.d5.weight.detach())

#Kernel Matrices, Need to be numpy
Ks = []
Ks.append(model.d1.weight.detach())
Ks.append(model.d2.weight.detach())
Ks.append(model.d3.weight.detach())
Ks.append(model.d4.weight.detach())
Ks.append(torch.tensor([0.0],dtype=torch.float32)) #spacer
Ks.append(torch.tensor([0.0],dtype=torch.float32))


Xs = [] # Xs are shape (output x #DP) ; however, typical python notation is reversed, so we take transpose here
Xs.append(x_0.T.detach())
Xs.append(x_1.T.detach())
Xs.append(x_2.T.detach())
Xs.append(x_3.T.detach())
Xs.append(x_4.T.detach())
Xs.append(x_5.T.detach())

#This is used to create arrays-- needs to be integer list to play nice with compilers
ds_int = []
ds_int.append(width*3*3) #channels_out * kernel_height * kernel_width
ds_int.append(width*3*3) #channels_out * kernel_height * kernel_width
ds_int.append(width*3*3) #channels_out * kernel_height * kernel_width
ds_int.append(width*3*3) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width

ds_array = [] #this is for the NTK formulation, 
#ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.

ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #The rest, even if you dont use NTK formulation, would be 1
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))

filters = []
filters.append(1)
filters.append(1)
filters.append(1)
filters.append(1)
filters.append(0)
filters.append(0)


padding = []
padding.append(1)
padding.append(1)
padding.append(1)
padding.append(1)
padding.append(0)
padding.append(0)


strides = []
strides.append(1)
strides.append(1)
strides.append(1)
strides.append(1)
strides.append(0)
strides.append(0)


layers=[model.d1,
        model.d2,
        model.d3,
        model.d4,
        0.0,
        model.d5
       ]



In [14]:
@njit
def calc_dw(x,w,b,pad,stride,H_,W_):
    """
    Calculates the derivative of conv(x,w) with respect to w
    
    output is shape:
        [datapoints, in_channels out_filters, kernel_height, kernel_width, out_filters, data_height, data_width
    
    'n f1 f2 c kh kw dh dw -> n (c f1 kh kw) (f2 dh dw)'
    """
    dx, dw, db = None, None, None
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    
    #dw = np.zeros((N,F,F,C,HH,WW,H_,W_),dtype=np.float32)
    
    dw = np.zeros((N,C,F,HH,WW,F,H_,W_),dtype=np.float32)

    xp = zero_pad(x,pad)
    #high priority, how to vectorize this operation?
    for n in range(N):
        for f in range(F):
            for i in range(HH): 
                for j in range(WW): 
                    for k in range(H_): 
                        for l in range(W_): 
                            for c in range(C): 
                                dw[n,c,f,i,j,f,k,l] += xp[n, c, i+stride*k, j+stride*l]                             
    
    return dw.reshape((N,(C*F*HH*WW),(F*H_*W_)))

@njit
def calc_dx(x,w,b,pad,stride,H_,W_):
    '''
    calculates the derivative of conv(x,w) with respect to x
    
    output is a nd-array of shape n x ch_in x og_h x og_w x (h_out w_out ch_out)
    '''
    dx, dw, db = None, None, None
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape 

    dx = np.zeros((C,H,W,F,H_,W_,),dtype=np.float32)
    #high priority, how to vectorize this operation? maybe with np.chunk,split?
    for f in range(F): 
        for i in range(H): 
            for j in range(W):
                for k in range(H_): 
                    for l in range(W_):
                        for c in range(C): 
                            if i-stride*k+pad > HH-1 or j-stride*l+pad > WW-1:
                                continue #this is alternative to padding w with zeros.
                            if i-stride*k+pad < 0 or j-stride*l+pad < 0:
                                continue #this is alternative to padding w with zeros.
                            dx[c,i,j,f,k,l] += w[f, c, i-stride*k+pad, j-stride*l+pad]
    #'c ih iw f oh ow -> (c ih iw) (f oh ow)'
    return dx.reshape(((C*H*W),(F*H_*W_)))
    
@njit
def zero_pad(A,pad):
    N, F, H, W = A.shape
    P = np.zeros((N, F, H+2*pad, W+2*pad),dtype=np.float32)
    P[:,:,pad:H+pad,pad:W+pad] = A
    return P
    
@njit
def cross(X):
    return X.T.dot(X)

def cross_pt_nonp(X,device='cuda'):
    X = X.to(device)
    return X.T.matmul(X)

def cross_pt(X,device='cuda'):
    X = torch.from_numpy(X).to(device)
    X = X.to(device)
    return X.T.matmul(X).cpu()

In [15]:
def compute_NTK_CNN(Ws: list, Ks: list, Xs: list, d_int: list, d_array: list, strides: list, padding: list, layers: list, d_activationt, device="cuda", ) -> list:
    '''
    MAIN:
    
    Inputs: 
    Ws: list, has length number of layers + any reshaping, contains dense layer weight tensors, detatched, and on device
    Ks: list, has length number of layers + any reshaping, contains 2d convolutional layers weight tensors, detatched, and on device
    Xs: list, has length number of layers + any reshaping, contains all intermediate outputs of each layer in the models
    d_int: list, has the number of bias parameters in each dense layer in the models
    d_array: list, has the value of which to sqrt and divide by in each layer, typically called the NTK normalization. else, its values are 1.
    strides: list, has the value of stride in each convolutional layer, else 0
    padding: list, has the value of padding in each convolutional layer, else 0
    layers: list, is a list containing the pytorch layers in the model, and "0" as a placeholder for a reshaping layer
    device: str, one of either 'cpu' or 'cuda'; must be the same as the model device location
    
    NOTE: all of the above lists should have the same length! See example
    
    OUTPUTS: list of torch.tensor objects, each the ntk 'component' for that layer. Given in backwards order, i.e. starting with the last layer. First weight, then bias.
    
    NOTE: to get the full ntk, simply sum over the layer dimension of the result: NTK = torch.sum(torch.stack(components),dim=(0))
    '''
    components = []
    
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
    n = Xs[0].shape[-1] #number of datapoints

    #holds the derivatives of activation, first value is empty list...?; just a spacer, replace with array
    Ds_dense = [np.array([[0.0]],dtype=np.float32)] 
    Ds_conv = [np.array([[0.0]],dtype=np.float32)]
    s_matrices = []
    with torch.no_grad():
        ####################################################################################################
        for l in range(0,L):
            if isinstance(layers[l],torch.nn.Linear):
                Ds_dense.append(d_activationt(layers[l](Xs[l].T)).T)
            else:
                Ds_dense.append(np.array([[0.0]],dtype=np.float32))
        ################################################################################################
        for l in range(0,L):
            if isinstance(layers[l],torch.nn.Conv2d):
                Ds_conv.append(d_activationt(layers[l](Xs[l].T)).reshape(n,-1).T)
            else:
                Ds_conv.append(np.array([[0.0]],dtype=np.float32))      
        ####################################################################################################
        S = torch.tensor([1.0],dtype=torch.float32).to(device) #this models the backward propogation:   
        for l in range(L,-1,-1):
            if isinstance(layers[l], torch.nn.Linear):

                components.append(cross_pt_nonp(S,device)*cross_pt_nonp(Xs[l],device)/d_array[l])

                W = torch.ones((d_int[l],n),dtype=torch.float32).to(device) * S
                components.append(cross_pt_nonp(W,device).to(device)/d_array[l])

            elif isinstance(layers[l], torch.nn.Conv2d):
                if len(S.shape) == 2: #this should only affect the very last layer, at which point, who cares.
                    S = S[None,:,:]
                    
                dw = calc_dw(x=Xs[l].T.cpu().numpy(),w=Ks[l].cpu().numpy(),b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
                
                W = torch.matmul(torch.from_numpy(dw).to(device),S.to(device))
                
                #We should bring this to zhichao or sombody and ask if there is obviously something faster?
                #W = np.diagonal(W,0,2,0)
                W = torch.diagonal(W,0,0,2)

                components.append(cross_pt_nonp(W,device).to(device)/d_array[l])

                N = Ks[l].shape[0]
                W = np.split(S.cpu().numpy(),N,axis=1)
                #W = torch.split(S,N,dim=1)
                
                W = np.array(W)
                #W = torch.stack(W)
                
                W = np.sum(W,axis=(1,2))
                #W = torch.sum(W,dim=(1,2))
                
                components.append(torch.from_numpy(cross(W,)).to(device)/d_array[l])
                #components.append(cross_pt_nonp(W,device))

            #############################
            #now we setup S for the next loop by treating appropriately
            if l==0:
                break

            if isinstance(layers[l], torch.nn.Linear):
                S = torch.matmul(S.T,Ws[l]).T / torch.sqrt(d_array[l])
                if len(S.shape) < 2:
                    S = S[:,None] #expand dimension along axis 1
                if not(isinstance(layers[l-1],float)): #this exludes the reshaping layer
                    S = Ds_dense[l]*S
                else: #and when the reshaping layer occurs we need to apply this instead
                    S = Ds_conv[l-1]*S

            elif isinstance(layers[l], torch.nn.Conv2d):
                dx = calc_dx(x=Xs[l].T.cpu().numpy(),w=Ks[l].cpu().numpy(),b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
                S = (torch.from_numpy(dx[None,:,:]).to(device) @ S) / torch.sqrt(d_array[l])
                S = Ds_conv[l]*S
            
    return components

In [16]:
ntk_components = compute_NTK_CNN(Ws, Ks, Xs, ds_int, ds_array, strides, padding, layers, d_activationt, device="cuda")

In [17]:
%%timeit
ntk_components = compute_NTK_CNN(Ws, Ks, Xs, ds_int, ds_array, strides, padding, layers, d_activationt, device="cuda")

27.9 ms ± 143 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
NTK = torch.sum(torch.stack(ntk_components),[0,])

In [19]:
NTK

tensor([[ 59836.4336,   3977.6104,   7510.1753],
        [  3977.6104,  55491.4805, -11823.7793],
        [  7510.1753, -11823.7793,  48860.9531]], device='cuda:0')

In [20]:
autograd_NTK

array([[1.9132348e-07, 0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00]], dtype=float32)