# This notebook first demonstrates the basic principle behind our method using autograd, then uses analytic derivatives to compute the NTK in parallel.

# The point is that, for simple architectures, the autograd NTK agrees with the analytic method up to a relative tolerance (rtol) of typically 1e-2; visual inspection shows that the elements are indeed very close.

# We can also use this notebook to benchmark the CNN NTK. Note, however, that the autograd NTK takes a long time to finish for large networks or for networks evaluated on many data points; consider skipping those cells.

In [2]:
#from layerwise_ntk import compute_NTK_CNN
import numpy as np
import random

import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

from numba import njit

In [1]:
from easy_ntk import compute_NTK_CNN

In [3]:
SEED = 1  # seed handed to the torch, random and numpy RNGs before each model is built
how_many = 10  # number of input samples (leading/batch dimension of x_test)
width = 8  # number of channels used by every conv layer of the models below

In [4]:
def activation(x):
    """Apply the element-wise tanh non-linearity used between layers."""
    return x.tanh()

def d_activationt(x):
    """Element-wise derivative of tanh, evaluated as cosh(x) ** -2."""
    hyperbolic_cosine = torch.cosh(x)
    return torch.pow(hyperbolic_cosine, -2)

In [5]:
def NTK_weights(m):
    """Re-initialise a layer's parameters in place with standard-normal samples.

    Written to be passed to ``nn.Module.apply``: only ``nn.Linear`` and
    ``nn.Conv2d`` modules are touched, every other module is ignored.
    No width-dependent rescaling is applied to the samples.

    Args:
        m: The module visited by ``apply``.

    Returns:
        None. The weight shape of each re-initialised layer is printed, and
        ``m.weight`` (plus ``m.bias`` when the layer has one) is overwritten.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    print(m.weight.shape)
    nn.init.normal_(m.weight.data)
    # Layers created with bias=False expose ``bias`` as None, so test identity
    # rather than relying on tensor comparison semantics.
    if m.bias is not None:
        nn.init.normal_(m.bias.data)

In [6]:
class dumb_small(torch.nn.Module):
    '''
    Small CNN: four 3x3, stride-1, padding-1 conv layers with the module-level
    ``activation`` applied after each, followed by one dense layer producing a
    single output per sample.

    The channel count comes from the module-level ``width``. Because the
    convolutions preserve the spatial size and the dense layer expects
    ``width*28*28`` features, inputs must have shape (batch, 1, 28, 28).

    NOTE: no activation function on the final layer!
    '''
    def __init__(self):
        super().__init__()

        self.d1 = torch.nn.Conv2d(1, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d2 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d3 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d4 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d5 = torch.nn.Linear(width * 28 * 28, 1, bias=True)

    def forward(self, x_0):
        '''Return the network output, shape (batch, 1), for input ``x_0``.'''
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        # Flatten per sample using the actual batch size of the input instead
        # of the global ``how_many``, so any batch size is handled correctly.
        x_5 = x_4.reshape(x_4.shape[0], -1)
        x_6 = self.d5(x_5)
        return x_6

class dumb_small_layerwise(torch.nn.Module):
    '''
    Same architecture as ``dumb_small`` (four 3x3, stride-1, padding-1 conv
    layers with the module-level ``activation``, then one dense layer), but
    ``forward`` also returns every intermediate tensor so that the layerwise
    NTK computation can consume them.

    Inputs must have shape (batch, 1, 28, 28); the channel count comes from
    the module-level ``width``.

    NOTE: no activation function on the final layer!
    '''
    def __init__(self):
        super().__init__()

        self.d1 = torch.nn.Conv2d(1, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d2 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d3 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d4 = torch.nn.Conv2d(width, width, 3, stride=1, padding=1, bias=True)  # 28 -> 28
        self.d5 = torch.nn.Linear(width * 28 * 28, 1, bias=True)

    def forward(self, x_0):
        '''
        Run the network and return ``(x_6, x_5, x_4, x_3, x_2, x_1, x_0)``:
        the output first, then the flattened features, the four post-activation
        conv outputs from last to first, and finally the input itself.
        '''
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        # Flatten per sample using the actual batch size of the input instead
        # of the global ``how_many``, so any batch size is handled correctly.
        x_5 = x_4.reshape(x_4.shape[0], -1)
        x_6 = self.d5(x_5)
        return x_6, x_5, x_4, x_3, x_2, x_1, x_0

In [7]:
# Seed every RNG in use so the weight initialisation below is reproducible.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cpu'

# Model that also returns its intermediate activations (used by the layerwise NTK).
model = dumb_small_layerwise()
model.apply(NTK_weights)  # prints each layer's weight shape as it is re-initialised
model.to(device)

# Re-seed with the same value before building the second model, so that it
# draws the same sequence of random numbers during construction and init.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Plain model returning only the output (used by the autograd NTK cell).
model_2 = dumb_small()
model_2.apply(NTK_weights)

# Standard-normal test inputs, converted to a float32 torch tensor.
x_test = np.random.normal(0,1,(how_many,1,28,28)).astype(np.float32) # shape (n, c_in, h, w)
x_test = torch.from_numpy(x_test)


torch.Size([8, 1, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([1, 6272])
torch.Size([8, 1, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([8, 8, 3, 3])
torch.Size([1, 6272])


# Autograd NTK -- uncomment and run if the number of data points and the number of filters are small (e.g. data points under 100 and filters under 8), in order to test that its result and the result of the easy_ntk layerwise algorithm agree with one another.

# We use this method because it most clearly exposes the NTK calculation to the reader. I am literally asking PyTorch to calculate the first derivatives of each output of the network with respect to the network's parameters, placing them into an array, and computing the Gramian. The only difference is that I am calculating it layerwise and adding the results.

In [8]:
#model_2.zero_grad()
#y = model_2(x_test)

#in the future we would iterate over layers instead of like this...
# layer_components_w1 = [] 
# layer_components_w2 = []
# layer_components_w3 = []
# layer_components_w4 = []
# layer_components_w5 = []

# layer_components_b1 = []
# layer_components_b2 = []
# layer_components_b3 = []
# layer_components_b4 = []
# layer_components_b5 = []

# for i in range(len(y)):
#     model_2.zero_grad()
#     y[i].backward(retain_graph=True)
#     #Get the tensors
#     w1_grad = model_2.d1.weight.grad.detach().numpy()
#     w2_grad = model_2.d2.weight.grad.detach().numpy()
#     w3_grad = model_2.d3.weight.grad.detach().numpy()
#     w4_grad = model_2.d4.weight.grad.detach().numpy()
#     w5_grad = model_2.d5.weight.grad.detach().numpy()
    
#     b1_grad = model_2.d1.bias.grad.detach().numpy()
#     b2_grad = model_2.d2.bias.grad.detach().numpy()
#     b3_grad = model_2.d3.bias.grad.detach().numpy()
#     b4_grad = model_2.d4.bias.grad.detach().numpy()
#     b5_grad = model_2.d5.bias.grad.detach().numpy()

#     #reshape and append. A copy is necessary, or else the stored entries would all be the same object
#     layer_components_w1.append(w1_grad.reshape(-1).copy())
#     layer_components_w2.append(w2_grad.reshape(-1).copy())
#     layer_components_w3.append(w3_grad.reshape(-1).copy())
#     layer_components_w4.append(w4_grad.reshape(-1).copy())
#     layer_components_w5.append(w5_grad.reshape(-1).copy())
    
#     layer_components_b1.append(b1_grad.reshape(-1).copy())
#     layer_components_b2.append(b2_grad.reshape(-1).copy())
#     layer_components_b3.append(b3_grad.reshape(-1).copy())
#     layer_components_b4.append(b4_grad.reshape(-1).copy())
#     layer_components_b5.append(b5_grad.reshape(-1).copy())

# layer_components_w1 = np.array(layer_components_w1)
# layer_components_w2 = np.array(layer_components_w2)
# layer_components_w3 = np.array(layer_components_w3)
# layer_components_w4 = np.array(layer_components_w4)
# layer_components_w5 = np.array(layer_components_w5)

# layer_components_b1 = np.array(layer_components_b1)
# layer_components_b2 = np.array(layer_components_b2)
# layer_components_b3 = np.array(layer_components_b3)
# layer_components_b4 = np.array(layer_components_b4)
# layer_components_b5 = np.array(layer_components_b5)

# autograd_NTK = layer_components_w1 @ layer_components_w1.T+\
#     layer_components_w2 @ layer_components_w2.T+\
#     layer_components_w3 @ layer_components_w3.T+\
#     layer_components_w4 @ layer_components_w4.T+\
#     layer_components_w5 @ layer_components_w5.T+\
#     layer_components_b1 @ layer_components_b1.T+\
#     layer_components_b2 @ layer_components_b2.T+\
#     layer_components_b3 @ layer_components_b3.T+\
#     layer_components_b4 @ layer_components_b4.T+\
#     layer_components_b5 @ layer_components_b5.T

#autograd_NTK

# Now Layerwise

In [14]:
x_test = x_test.to('cpu')
x_6, x_5, x_4, x_3, x_2, x_1, x_0 = model(x_test)

# Every list below has one entry per stage of the network, in forward order:
# indices 0-3 are the conv layers d1-d4, index 4 is the reshape (flatten)
# step and index 5 is the dense layer d5. Slots that do not apply to a stage
# hold a placeholder ("spacer").

# Dense weight matrices, as detached torch tensors; only d5 has one.
Ws = [torch.tensor([0.0], dtype=torch.float32) for _ in range(5)]  # spacers
Ws.append(model.d5.weight.detach())

# Convolution kernels, as detached torch tensors, followed by two spacers.
Ks = [conv.weight.detach() for conv in (model.d1, model.d2, model.d3, model.d4)]
Ks += [torch.tensor([0.0], dtype=torch.float32) for _ in range(2)]

# Input to each stage with its axes reversed, so the data-point axis comes
# last (the Xs are treated as "output x #DP", the reverse of the usual
# batch-first layout). An explicit permute is used because ``Tensor.T`` is
# deprecated for tensors that are not 2-D.
Xs = [
    act.permute(*reversed(range(act.dim()))).detach()
    for act in (x_0, x_1, x_2, x_3, x_4, x_5)
]

# Used to create arrays -- needs to be a list of plain ints to play nice with
# compilers. Conv stages use width * kernel_height * kernel_width; the
# reshape and dense stages use 1.
ds_int = [width * 3 * 3] * 4 + [1, 1]

# Per-stage scale factors for the NTK formulation. The first element is a
# spacer and could be anything; the rest would be 1 even without the NTK
# formulation.
ds_array = [torch.tensor([1.0], dtype=torch.float32).to(device) for _ in range(6)]

# Per-stage flags: 1 for the four conv stages, 0 for the reshape and dense
# stages. NOTE(review): not passed to compute_NTK_CNN in the cells shown here.
filters = [1, 1, 1, 1, 0, 0]

# Conv padding and stride of each stage (0 where there is no convolution).
padding = [1, 1, 1, 1, 0, 0]
strides = [1, 1, 1, 1, 0, 0]

# The layer objects themselves; 0.0 stands in for the reshape step.
layers = [
    model.d1,
    model.d2,
    model.d3,
    model.d4,
    0.0,
    model.d5,
]



In [17]:
# Time a single evaluation of the layerwise NTK and print the elapsed seconds.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals, unlike time.time(), which follows the (adjustable) system clock.
now = time.perf_counter()
ntk_components = compute_NTK_CNN(Ws, Ks, Xs, ds_int, ds_array, strides, padding, layers, d_activationt, device="cpu")
print(time.perf_counter() - now)

3.6703555583953857


In [18]:
%%timeit
# Benchmark the same compute_NTK_CNN call as the single-shot timing cell, this time via IPython's timeit magic.
ntk_components = compute_NTK_CNN(Ws, Ks, Xs, ds_int, ds_array, strides, padding, layers, d_activationt, device="cpu")

55.3 s ± 2.61 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Stack the components returned by compute_NTK_CNN and add them up along the
# new leading axis to obtain the full NTK.
NTK = torch.stack(ntk_components).sum(dim=0)

In [None]:
NTK

In [None]:
autograd_NTK

In [None]:
# Compare the layerwise NTK with the autograd one at a relative tolerance of 1e-3.
# autograd_NTK only exists if the commented-out autograd cell above was run.
np.allclose(NTK.cpu().numpy(), autograd_NTK, rtol=1e-3)