In [128]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pickle
import os
import torch
from torch import nn

In [129]:
N_datapoints = 10_000
input_features = 5
hidden_layer = 2000

In [134]:
def activation(x):
    return torch.tanh(x)
    
def d_activation(x):
    return torch.cosh(x)**-2

# def activation(x):
#     return 2/np.pi * torch.arctan(np.pi/2 * x)

# def d_activation(x):
#     return 4/(np.pi**2 * x**2 + 4)

In [135]:
X = np.random.normal(0,1,(N_datapoints,input_features)).astype(np.float32)
X = torch.from_numpy(X)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(input_features,hidden_layer,bias=False)
        self.fc2 = nn.Linear(hidden_layer,1,bias=False)
    def forward(self, x_0):
        x_1 = activation(self.fc1(x_0))
        x_2 = self.fc2(x_1) # instead of Heaviside step fn
        return x_2, x_1, x_0
    
model = Model()
x_2, x_1, x_0 = model(X)

In [136]:
Ws = []
Ws.append(model.fc1.weight.detach().numpy())
Ws.append(model.fc2.weight.detach().numpy())

Xs = []
Xs.append(x_0.numpy().T)
Xs.append(x_1.detach().numpy().T)

ds = []
ds.append(5)
ds.append(model.fc1.weight.shape[0])

In [137]:
def cross(X):
    return np.dot(np.transpose(X),X)

In [138]:
# Ws =     [W1, W2, ..., WL, w]
# Xs = [X0,X1, ..., XL]
# d = [d0, d1, ..., dL]

#therefore these all have the same lengths

def compute_NTK(Ws,Xs,d): #L counts from 1 to number of layers.
    L = len(Xs)-1 #number of layers, Xs goes from inputs to outputs; X_0 is the input, X_L is the output
    n = Xs[0].shape[1] #number of datapoints
    Ds = [[]] #holds the derivatives, first value is empty list...?
    for l in range(L):
        Ds.append(d_activation(np.dot(Ws[l],Xs[l])))
    KNTK = cross(Xs[L]) #this is eventually summed over
    for l in range(1,L+1):
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        XtX = cross(Xs[l-1])
        S = np.zeros((d[l],n))
        for i in range(n):
            s = Ws[-1].T.reshape(-1)/np.sqrt(d[L])
            for k in range(L,l-1,-1):
                s = Ds[k][:,i]*s
                if k > l:
                    s = np.dot(Ws[k-1],s)/np.sqrt(d[k-1])
            S[:,i] = s
        KNTK += cross(S) * XtX
    return KNTK

In [139]:
print('Number of parameters: ',sum(p.numel() for p in model.parameters() if p.requires_grad))
print('Number of datapoints: ',N_datapoints)

Number of parameters:  12000
Number of datapoints: 


In [124]:
%%timeit
KNTK = compute_NTK(Ws,Xs,ds)

6.65 s ± 402 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
