# Synopsis: 1) JAX still doesn't work on Windows, but 2) JAX jit is faster than numba for small networks with a larger number of datapoints (other scenarios have not been tested).



In [None]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

import sys
from pathlib import Path

from numba import njit
from jax import numpy as jnp

In [None]:
# Identity activation
def activation(x):
    """Return the pre-activation ``x`` unchanged (identity / linear network)."""
    unchanged = x
    return unchanged

#@njit
def d_activation(x):
    """Elementwise derivative of the identity activation.

    Only the shape of ``x`` is used: the result is a float32 array of ones
    with that shape (the slope of ``activation(x) = x`` is 1 everywhere).
    """
    out_shape = np.shape(x)
    return np.ones(out_shape, dtype=np.float32)

#Tanh
# def activation(x):
#     return torch.tanh(x)

#@njit
# def d_activation(x):
#     return np.cosh(x)**-2

In [None]:
# Rebind d_activation to its numba-njit wrapper so that the njit-compiled
# compute_NTK defined below can call it.
d_activation = njit(d_activation)

In [None]:
# Experiment settings.
# NOTE(review): in the cells visible here only SEED is referenced; the model
# hard-codes 5 inputs / 256 hidden units and x_test uses 200 datapoints, so the
# other three values appear to be unused — confirm before relying on them.
input_features = 5
hidden_layer = 128
N_datapoints = 100
SEED = 1

In [None]:

def cross(X):
    # Gram matrix over the columns of X: X^T X.
    X_t = np.transpose(X)
    return np.dot(X_t, X)


def compute_NTK(Ws, Xs, d_int, d_array):#L counts from 1 to number of layers.
    '''
    Layerwise neural tangent kernel (NTK) of a bias-free dense network with a
    single output unit.

    The kernel is accumulated as
        K = X_L^T X_L + sum_{l=1..L} (S_l^T S_l) * (X_{l-1}^T X_{l-1})
    where * is the elementwise product and column i of S_l is the signal
    back-propagated from the output to the pre-activations of layer l for
    datapoint i.

    Ws, a list of the weights as np.array type np.float32, each stored as
        (layer outputs x layer inputs),                                         [W1, W2, W3 ... W]
    Xs, a list of the layer inputs as np.array type np.float32, each stored as
        (features x datapoints); X_0 is the network input, X_L feeds the last layer, [X0, X1, X2, ... XL]
    d_int, a list of the dimensionality of X_l as int64,                        [d0, d1, d2, ... dL]
    d_array, a list of the dimensionality of X_l, each a one-element np.array
        of type np.float32,                                                     [d0, d1, d2, ... dL]
    The widths are passed both as ints and as float arrays because numba
    doesn't like type conversion.

    outputs the (datapoints x datapoints) NTK as a np.array of type np.float32
    '''
    L = len(Xs)-1 #number of hidden layers; Ws holds L+1 matrices
    n = Xs[0].shape[1] #number of datapoints
    # Ds[k] is the activation derivative at the pre-activations of layer k.
    # Index 0 is only a spacer so that Ds lines up with the 1-based layer index;
    # it is an array (not None / an empty list) to keep the list homogeneous.
    Ds = [np.array([[0.0]],dtype=np.float32)]
    for l in range(L):
        Ds.append(d_activation(np.dot(Ws[l],Xs[l])))
    KNTK = cross(Xs[L]) #contribution of the last weight matrix
    # Gradient of the output w.r.t. the last layer's input; it is the same for
    # every layer and every datapoint, so compute it once.
    top = Ws[-1].T.reshape(-1)/np.sqrt(d_array[L])
    for l in range(1,L+1):
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        XtX = cross(Xs[l-1])
        S = np.zeros((d_int[l],n),dtype=np.float32)
        for i in range(n):
            s = top
            for k in range(L,l-1,-1):
                s = Ds[k][:,i]*s
                if k > l:
                    # Step back through layer k. The forward pass is
                    # np.dot(Ws[k-1], X), so the backward pass needs the
                    # transpose; without it the product is only shape-valid
                    # when the two hidden widths happen to be equal.
                    s = np.dot(Ws[k-1].T,s)/np.sqrt(d_array[k-1])
            S[:,i] = s
        KNTK += cross(S) * XtX
    return KNTK

In [None]:
# Rebind the pure-Python helpers to their numba-njit wrappers; from here on the
# names cross / compute_NTK refer to the compiled versions.
cross = njit(cross)

compute_NTK = njit(compute_NTK)

In [None]:
def NTK_weights(m):
    """Re-initialise a Linear/Conv2d module in place with normal-distributed parameters.

    Written to be passed to ``module.apply``: modules of any other type are
    left untouched.  For every module that is re-initialised the weight shape
    is printed, then the weight (and the bias, when the module has one) is
    redrawn with ``nn.init.normal_`` using its default arguments.  No
    width-dependent rescaling of the parameters is applied here.
    """
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return
    print(m.weight.shape)
    nn.init.normal_(m.weight.data)
    # Identity test instead of `!= None`: bias is None for bias=False layers,
    # and comparing a tensor with `!=` is not a reliable None check.
    if m.bias is not None:
        nn.init.normal_(m.bias.data)

In [None]:
#Layerwise Needs each conjugate Kernel
class dumb_small_layerwise(torch.nn.Module):
    '''
    Small bias-free dense network (5 -> 256 -> 256 -> 1) for test cases.

    The forward pass returns the output together with every intermediate
    representation, so the layerwise kernel computation can use them.
    '''
    def __init__(self,):
        super().__init__()

        # Layers are created in the order d1, d2, d3.
        self.d1 = torch.nn.Linear(5, 256, bias=False)
        self.d2 = torch.nn.Linear(256, 256, bias=False)
        self.d3 = torch.nn.Linear(256, 1, bias=False)

    def forward(self, x_0):
        # Both hidden representations are divided by sqrt(256); the output is not.
        scale = np.sqrt(256)

        pre_1 = self.d1(x_0)
        x_1 = activation(pre_1) / scale

        pre_2 = self.d2(x_1)
        x_2 = activation(pre_2) / scale

        pre_3 = self.d3(x_2)
        x_3 = activation(pre_3)

        # Returned from the output back to the input.
        return x_3, x_2, x_1, x_0

In [None]:
# Seed every RNG in use (torch, stdlib random, numpy) before building a model.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cpu'

# NOTE(review): model_small is not referenced again in the cells visible here.
model_small = dumb_small_layerwise()
model_small.to(device)
model_small.apply(NTK_weights)

# Reset the seeds so that model_layerwise is built from the same RNG state as
# model_small (presumably giving it identical initial weights — confirm).
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)


model_layerwise = dumb_small_layerwise()
model_layerwise.to(device)
model_layerwise.apply(NTK_weights)

# Test inputs: 200 datapoints x 5 features drawn from N(0, 1), stored as a
# float32 torch tensor.
x_test = np.random.normal(0,1,(200,5)).astype(np.float32)
x_test = torch.from_numpy(x_test)


In [None]:
# One forward pass yields the output and every intermediate representation.
x_3, x_2, x_1, x_0 = model_layerwise(x_test)

# Weight matrices of the three layers, in network order, as float32 numpy arrays.
Ws = [
    layer.weight.detach().numpy().astype(np.float32)
    for layer in (model_layerwise.d1, model_layerwise.d2, model_layerwise.d3)
]

# compute_NTK wants each X as (features x datapoints), whereas the model works
# with (datapoints x features) — hence the transpose.
Xs = [rep.detach().numpy().T.astype(np.float32) for rep in (x_0, x_1, x_2)]

# Widths of X_0, X_1, X_2: the input length followed by the hidden widths. The
# final output length (assumed to be 1) is left out. The widths are supplied
# twice: as plain ints and as one-element float32 arrays.
ds_int = [5, 256, 256]
ds_array = [np.array([width], dtype=np.float32) for width in (5.0, 256.0, 256.0)]

NTK_layerwise = compute_NTK(Ws, Xs, ds_int, ds_array)

In [None]:
%%timeit #288ms
# Times the njit-wrapped compute_NTK on the inputs built in the previous cell;
# the "288ms" on the magic line is the author's recorded timing.
NTK_layerwise = compute_NTK(Ws, Xs, ds_int, ds_array)

In [None]:
def jax_cross(X):
    """Gram matrix over the columns of ``X`` (``X^T X``), computed with jax.numpy."""
    X_t = jnp.transpose(X)
    return jnp.dot(X_t, X)

def jax_d_activation(x):
    """Elementwise derivative of the identity activation, as a JAX array.

    Only the shape of ``x`` is used: a float32 numpy array of ones with that
    shape is built and then converted with ``jnp.array``.
    """
    ones_np = np.ones(np.shape(x), dtype=np.float32)
    return jnp.array(ones_np)

  

In [None]:
def jax_compute_NTK(Ws, Xs, d_int, d_array):#L counts from 1 to number of layers.
    '''
    JAX version of the layerwise neural tangent kernel (NTK) of a bias-free
    dense network with a single output unit.

    The kernel is accumulated as
        K = X_L^T X_L + sum_{l=1..L} (S_l^T S_l) * (X_{l-1}^T X_{l-1})
    where * is the elementwise product and column i of S_l is the signal
    back-propagated from the output to the pre-activations of layer l for
    datapoint i.

    Ws, a list of the weights as float32 arrays, each stored as
        (layer outputs x layer inputs),                                         [W1, W2, W3 ... W]
    Xs, a list of the layer inputs as float32 arrays, each stored as
        (features x datapoints); X_0 is the network input, X_L feeds the last layer, [X0, X1, X2, ... XL]
    d_int, the dimensionality of X_l as ints,                                   [d0, d1, d2, ... dL]
        Kept only so that existing call sites (including jit wrappers that mark
        this argument as static) keep working; it is not read, because S is
        built for all datapoints at once instead of being pre-allocated.
    d_array, the dimensionality of X_l, each a one-element float32 array,       [d0, d1, d2, ... dL]

    outputs the (datapoints x datapoints) NTK as a JAX array
    '''
    L = len(Xs)-1 #number of hidden layers; Ws holds L+1 matrices
    # Ds[k] is the activation derivative at the pre-activations of layer k,
    # shape (d_k x datapoints). Index 0 is only a spacer so that Ds lines up
    # with the 1-based layer index.
    Ds = [None]
    for l in range(L):
        Ds.append(jax_d_activation(jnp.dot(Ws[l],Xs[l])))
    KNTK = jax_cross(Xs[L]) #contribution of the last weight matrix
    # Gradient of the output w.r.t. the last layer's input, kept as a column
    # (d_L x 1) so that it broadcasts against every datapoint at once.
    top = Ws[-1].T.reshape(-1,1)/jnp.sqrt(d_array[L])
    for l in range(1,L+1):
        #we are going to construct terms that look like ( S^T S ) * (X^T X)
        XtX = jax_cross(Xs[l-1])
        # All columns of S are propagated together. JAX arrays are immutable,
        # so the earlier per-column `S.at[:,i].set(s)` (whose result was
        # discarded) left S at zero and dropped every layer term.
        S = top
        for k in range(L,l-1,-1):
            S = Ds[k]*S
            if k > l:
                # Step back through layer k. The forward pass is
                # jnp.dot(Ws[k-1], X), so the backward pass needs the transpose.
                S = jnp.dot(Ws[k-1].T,S)/jnp.sqrt(d_array[k-1])
        KNTK = KNTK + jax_cross(S) * XtX
    return KNTK

In [None]:
# Reference run of the JAX implementation; in cell order this happens before
# jax_compute_NTK is wrapped with jit, and it uses the list-valued ds_int.
jax_NTK_layerwise = jax_compute_NTK(Ws, Xs, ds_int, ds_array)

In [None]:
from jax import jit

# Rebind the JAX functions to their jit-wrapped versions. d_int (positional
# argument index 2) is marked static, i.e. it is treated as a compile-time
# constant instead of being traced.
jax_compute_NTK = jit(jax_compute_NTK, static_argnums=(2,))
jax_cross = jit(jax_cross)
jax_d_activation = jit(jax_d_activation)

In [None]:
# Tuple copy of ds_int for the jit-wrapped jax_compute_NTK, whose d_int argument
# is static; jit requires static arguments to be hashable, which a list is not.
ds_int_tuple = (5, 256, 256)

In [None]:
# Run the jit-wrapped JAX implementation with the tuple-valued layer widths.
jaxjit_NTK_layerwise = jax_compute_NTK(Ws, Xs, ds_int_tuple, ds_array)