In [1]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt


import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

#from easy_ntk import calculate_NTK
from einops import rearrange

import time

import sys
from pathlib import Path

from numba import njit
from numba.typed import List

In [2]:
SEED = 0

In [3]:
def activation(x):
    return torch.tanh(x)

def d_activationt(x):
    return torch.cosh(x)**-2

In [4]:
def NTK_weights(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]
    if isinstance(m, nn.Conv2d):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]

In [5]:
class FC(torch.nn.Module):
    '''
    simple network for test cases
    
    
    It seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super(FC, self).__init__()
        #input size=(N,784)
        self.d1 = torch.nn.Linear(784,100,bias=False)

        self.d2 = torch.nn.Linear(100,100,bias=False)
        
        self.d3 = torch.nn.Linear(100,100,bias=False)
        
        self.d4 = torch.nn.Linear(100,1,bias=False)
        
    def forward(self, x0):
        x1 = 1/np.sqrt(100) * activation(self.d1(x0))
        x2 = 1/np.sqrt(100) * activation(self.d2(x1))
        x3 = 1/np.sqrt(100) * activation(self.d3(x2))
        x4 = self.d4(x3)
        return x4, x3, x2, x1, x0

In [6]:
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cuda'

model = FC()
model.to(device)
model.apply(NTK_weights)

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

model2 = FC()
model2.to(device)
model2.apply(NTK_weights)

x_test = np.random.normal(0,1,(3000,784)).astype(np.float32) #n c_in, h, w
x_test = torch.from_numpy(x_test)
x_test = x_test.to('cuda')

torch.Size([100, 784])
torch.Size([100, 100])
torch.Size([100, 100])
torch.Size([1, 100])
torch.Size([100, 784])
torch.Size([100, 100])
torch.Size([100, 100])
torch.Size([1, 100])


In [7]:
x_4, x_3, x_2, x_1, x_0 = model(x_test)

#These need to be numpy
Ws = []
Ws.append(model.d1.weight.detach())
Ws.append(model.d2.weight.detach())
Ws.append(model.d3.weight.detach())
Ws.append(model.d4.weight.detach())

#Kernel Matrices, Need to be numpy
Ks = []
Ks.append(torch.tensor([0.0],dtype=torch.float32)) 
Ks.append(torch.tensor([0.0],dtype=torch.float32))
Ks.append(torch.tensor([0.0],dtype=torch.float32))
Ks.append(torch.tensor([0.0],dtype=torch.float32))


Xs = [] # Xs are shape (output x #DP) ; however, typical python notation is reversed, so we take transpose here
Xs.append(x_0.T.detach())
Xs.append(x_1.T.detach())
Xs.append(x_2.T.detach())
Xs.append(x_3.T.detach())

#This is used to create arrays-- needs to be integer list to play nice with compilers
d_int = []
d_int.append(100)
d_int.append(100)
d_int.append(100)
d_int.append(100)

d_array = [] #this is for the NTK formulation, 
#ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.

d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device)) 
d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device)) 
d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device))
d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device))
d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device))

filters = []
filters.append(0)
filters.append(0)
filters.append(0)
filters.append(0)

padding = []
padding.append(0)
padding.append(0)
padding.append(0)
padding.append(0)

strides = []
strides.append(0)
strides.append(0)
strides.append(0)
strides.append(0)

layers=[model.d1,
        model.d2,
        model.d3,
        model.d4,
       ]

In [8]:
from easy_ntk import compute_NTK_CNN

In [9]:
components = compute_NTK_CNN(Ws, Ks, Xs, d_int, d_array, strides, padding, layers, d_activationt, device="cuda",)

In [10]:
NTK = torch.sum(torch.stack(components),dim=0).cpu().numpy()

# First, can we show the NTK equation is correct?

In [11]:
#in order to do that we need the initial parameters and the parameters after an update step.

#Ws are the initial weights

In [12]:
assert torch.all(model2.d1.weight == model.d1.weight)
assert torch.all(model2.d2.weight == model.d2.weight)
assert torch.all(model2.d3.weight == model.d3.weight)
assert torch.all(model2.d4.weight == model.d4.weight)

In [13]:
# so lets do an update step. lets pretend we are solving a regression problem, so we need to get the targets

y_test = torch.sin(np.pi*100*torch.mean(x_test,dim=1))[:,None]

In [14]:
lr = 1e-2

out = model(x_test)[0]

out.retain_grad()

loss = torch.sum((y_test - out)**2)

loss.backward()

with torch.no_grad():
    for name, W in model.named_parameters():
        W-= lr*W.grad

dwc = out.grad.detach().cpu().numpy()

In [15]:
#okay so now model has the single update, model 2 has the first update.
#if lr is small enough, then the ntk equation holds

In [16]:
leftside = (model(x_test)[0] - model2(x_test)[0]).detach().cpu().numpy()

In [17]:
rightsides = []
for j in range(len(leftside)):
    rightside = 0
    for i in range(len(leftside)):
        rightside+=NTK[j,i] * dwc[i,0]
    rightside*= lr
    rightsides.append(rightside)
    
rightside = np.array(rightsides)

In [18]:
rightside[0]

-1.405895449149557

In [19]:
leftside[0]

array([0.683164], dtype=float32)

In [20]:
np.sum(np.isclose(leftside,rightside, rtol=1e-3, atol=1e-100))

#411  1e-2
#1693 1e-3
#1373 1e-4
#1392 1e-5
#1231 1e-6
#918  1e-7
#0 1e-9 (I think we hit computer precision for pytorch)

411