# This notebook demonstrates that the various NTK methods all agree with one another, and shows how to set up the calculation correctly for each

In [66]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt


import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torchvision import datasets

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

import sys
from pathlib import Path

from numba import njit

import os

import gc

In [67]:
from torchntk.autograd import naive_ntk
from torchntk.autograd import old_autograd_ntk
from torchntk.explicit import explicit_ntk
from torchntk.autograd import vmap_ntk_loader
from torchntk.autograd import autograd_components_ntk

In [68]:
# Seed handed to torch / random / numpy before the models and the data are created.
SEED = 0
HOW_MANY = 3 #number of input samples; kept tiny since the goal is only to demonstrate the different methods and confirm they give the same answer

# Architecture Definition

In [69]:
def activation(x):
    """Elementwise hyperbolic tangent, the nonlinearity used by the networks below."""
    return x.tanh()

def d_activationt(x):
    """Elementwise derivative of tanh, evaluated as cosh(x) raised to the power -2."""
    cosh_x = torch.cosh(x)
    return cosh_x.pow(-2)

In [70]:
def NTK_weights(m):
    """Re-initialise an ``nn.Linear`` layer with standard-normal parameters.

    Meant to be passed to ``Module.apply``: modules that are not
    ``nn.Linear`` are left untouched. For a linear layer the weight shape is
    printed, the weight is overwritten in place with N(0, 1) samples and, when
    the layer has a bias, the bias is overwritten the same way.

    Parameters
    ----------
    m : torch.nn.Module
        Module visited by ``apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    print(m.weight.shape)
    nn.init.normal_(m.weight.data)
    # Identity check: bias-free layers hold ``None`` here, and comparing a
    # tensor to None with ``!=`` is not a reliable way to test for that.
    if m.bias is not None:
        nn.init.normal_(m.bias.data)

#### We need two different, but similar, architectures

* FC_layered, which we use for the explicit NTK calculation; it returns all intermediate layer outputs, which that calculation needs
* FC, which has a typical return of just the single output neuron of the network


We will use the same seed to set the weights, so that we can guarantee our two networks are the same

In [71]:
class FC_layered(torch.nn.Module):
    '''
    Small bias-free fully connected test network: 100 -> 10 -> 100 -> 10 -> 1.

    Every hidden layer applies ``activation`` and is rescaled by
    1/sqrt(width of that layer). Unlike a typical module, ``forward`` hands
    back every intermediate output together with the input.

    Bias terms are switched off in every layer; per the original author's
    note, adding them does not seem to be trivial.
    '''
    def __init__(self):
        super().__init__()
        # expected input size = (N, 100)
        self.d1 = torch.nn.Linear(100, 10, bias=False)
        self.d2 = torch.nn.Linear(10, 100, bias=False)
        self.d3 = torch.nn.Linear(100, 10, bias=False)
        self.d4 = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x0):
        # hidden layer 1 (width 10)
        h1 = activation(self.d1(x0))
        x1 = 1/np.sqrt(10) * h1
        # hidden layer 2 (width 100)
        h2 = activation(self.d2(x1))
        x2 = 1/np.sqrt(100) * h2
        # hidden layer 3 (width 10)
        h3 = activation(self.d3(x2))
        x3 = 1/np.sqrt(10) * h3
        # linear read-out, no activation and no rescaling
        x4 = self.d4(x3)
        # NOTICE: output first, then the intermediate outputs, input last
        return x4, x3, x2, x1, x0
    
class FC(torch.nn.Module):
    '''
    Small bias-free fully connected test network: 100 -> 10 -> 100 -> 10 -> 1.

    Every hidden layer applies ``activation`` and is rescaled by
    1/sqrt(width of that layer). ``forward`` returns only the final output.

    Bias terms are switched off in every layer; per the original author's
    note, adding them does not seem to be trivial.
    '''
    def __init__(self):
        super().__init__()
        # expected input size = (N, 100)
        self.d1 = torch.nn.Linear(100, 10, bias=False)
        self.d2 = torch.nn.Linear(10, 100, bias=False)
        self.d3 = torch.nn.Linear(100, 10, bias=False)
        self.d4 = torch.nn.Linear(10, 1, bias=False)

    def forward(self, x0):
        # hidden layer 1 (width 10)
        h1 = activation(self.d1(x0))
        x1 = 1/np.sqrt(10) * h1
        # hidden layer 2 (width 100)
        h2 = activation(self.d2(x1))
        x2 = 1/np.sqrt(100) * h2
        # hidden layer 3 (width 10)
        h3 = activation(self.d3(x2))
        x3 = 1/np.sqrt(10) * h3
        # linear read-out, no activation and no rescaling
        return self.d4(x3)

In [72]:
# Seed torch, random and numpy so the weights drawn below are reproducible.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cuda'

# Build the plain network, move it to `device`, then re-draw its weights with NTK_weights.
model = FC()
model.to(device)
model.apply(NTK_weights)
# ---- second network ----

# The layerwise (explicit) calculation depends on a special nn.Module object
# that returns the tuple over each layer's output.
# Re-seed with the same SEED and repeat the exact same steps in the same order,
# with the intent that model2 ends up with the same weights as model.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

model2 = FC_layered()
model2.to(device)
model2.apply(NTK_weights)

torch.Size([10, 100])
torch.Size([100, 10])
torch.Size([10, 100])
torch.Size([1, 10])
torch.Size([10, 100])
torch.Size([100, 10])
torch.Size([10, 100])
torch.Size([1, 10])


FC_layered(
  (d1): Linear(in_features=100, out_features=10, bias=False)
  (d2): Linear(in_features=10, out_features=100, bias=False)
  (d3): Linear(in_features=100, out_features=10, bias=False)
  (d4): Linear(in_features=10, out_features=1, bias=False)
)

# Now setup Data

In [73]:
# Re-seed so the random inputs are reproducible regardless of what ran before.
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
# HOW_MANY inputs with 100 features each (matching d1's input width), filled
# in place with mean-0 / std-1 normal samples; created on the CPU.
train_x = torch.empty((HOW_MANY,100),device='cpu').normal_(0,1)

# Torch.Jacobian

In [74]:
#%%timeit
# naive_ntk is slow, so it is only run for small sample counts.
# NOTE: NTK_naive stays undefined when HOW_MANY > 1000.
if HOW_MANY <= 1000: #takes a long time
    # train_x is created on the CPU, matching model.cpu().
    NTK_naive = naive_ntk(model.cpu(), train_x)

# NTK old Autograd

In [75]:
# model.cpu() was used for the naive computation above; move the model back onto `device` for the remaining methods.
model.to(device)

FC(
  (d1): Linear(in_features=100, out_features=10, bias=False)
  (d2): Linear(in_features=10, out_features=100, bias=False)
  (d3): Linear(in_features=100, out_features=10, bias=False)
  (d4): Linear(in_features=10, out_features=1, bias=False)
)

In [76]:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [77]:
# All-zero targets, one per input, so the inputs can be wrapped in a TensorDataset.
tgts = torch.zeros(HOW_MANY, device='cpu')
# Unshuffled loader over (input, target) pairs.
xloader = DataLoader(
    TensorDataset(train_x, tgts),
    batch_size=128,
    shuffle=False,
    pin_memory=True,
)

In [78]:
# Use the notebook-level `device` rather than repeating the literal, so the
# model and this call cannot drift apart.
NTK_old_autograd = old_autograd_ntk(xloader, model, device=device)

# NTK layerwise Autograd

In [79]:
# Forward pass on `device`; column 0 of the output is what gets handed to autograd_components_ntk.
y = model(train_x.to(device))
NTK_autograd_components = autograd_components_ntk(model, y[:, 0])
# Add up the per-key component tensors to obtain the full NTK.
NTK_autograd_full = torch.stack(list(NTK_autograd_components.values())).sum(dim=0)

# NTK layerwise by hand

In [80]:
def forward(model, x_test):
    """Run the layered model and package the keyword arguments for ``explicit_ntk``.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose forward pass returns ``(x_4, x_3, x_2, x_1, x_0)`` and
        that exposes the linear layers ``d1`` .. ``d4`` (e.g. ``FC_layered``).
    x_test : torch.Tensor
        Input batch, passed straight to ``model``.

    Returns
    -------
    tuple
        ``(x_4[:, 0], params)`` where ``params`` is a dict with the keys
        ``'Xs'``, ``'d_int'``, ``'d_array'``, ``'layers'``,
        ``'d_activationt'`` and ``'device'``.

    Notes
    -----
    Relies on the notebook-level names ``device`` and ``d_activationt``.
    """
    x_4, x_3, x_2, x_1, x_0 = model(x_test)

    layers = [model.d1, model.d2, model.d3, model.d4]

    # Xs are shape (output x #DP); however, typical python notation is
    # reversed, so we take the transpose here. Detached from the graph.
    Xs = [x.T.detach() for x in (x_0, x_1, x_2, x_3)]

    # Output width of every layer, read from the layers themselves so the
    # values cannot drift from the architecture. Kept as a plain integer list
    # (the original note says this is needed to play nice with compilers).
    d_int = [layer.out_features for layer in layers]

    # Same widths as one-element float32 tensors on `device`, for the NTK formulation.
    d_array = [
        torch.tensor([float(width)], dtype=torch.float32).to(device)
        for width in d_int
    ]

    params = {
        'Xs': Xs,
        'd_int': d_int,
        'd_array': d_array,
        'layers': layers,
        'd_activationt': d_activationt,
        'device': 'cuda',
    }
    return x_4[:, 0], params

In [81]:
# Reset gradients on the layered model before the forward pass.
model2.zero_grad()

# `params` holds the keyword arguments explicit_ntk is called with.
output, params = forward(model2, train_x.to(device))

NTK_explicit_components = explicit_ntk(**params)
# Add up the stacked components to obtain the full NTK.
NTK_explicit_full = torch.stack(NTK_explicit_components).sum(dim=0)

# NTK with Vmap

In [82]:
NTK_vmap_components = vmap_ntk_loader(model, xloader)
# Add up the per-key component tensors to obtain the full NTK.
NTK_vmap_full = torch.stack(list(NTK_vmap_components.values())).sum(dim=0)

# Compare full NTKs

In [88]:
# Full NTK from naive_ntk (called with model.cpu()); only defined when HOW_MANY <= 1000.
NTK_naive

tensor([[ 5.8789, -0.0399, -0.1851],
        [-0.0399,  0.5906,  0.3424],
        [-0.1851,  0.3424,  2.0714]])

In [89]:
# Full NTK from old_autograd_ntk, computed from the data loader.
NTK_old_autograd

tensor([[ 5.8831, -0.0398, -0.1852],
        [-0.0398,  0.5904,  0.3423],
        [-0.1852,  0.3423,  2.0716]], device='cuda:0')

In [90]:
# Full NTK obtained by summing the components returned by autograd_components_ntk.
NTK_autograd_full

tensor([[ 5.8831, -0.0398, -0.1852],
        [-0.0398,  0.5904,  0.3423],
        [-0.1852,  0.3423,  2.0716]], device='cuda:0')

In [91]:
# Full NTK obtained by summing the components returned by explicit_ntk (uses model2).
NTK_explicit_full

tensor([[ 5.8781, -0.0398, -0.1851],
        [-0.0398,  0.5903,  0.3423],
        [-0.1851,  0.3423,  2.0708]], device='cuda:0')

In [92]:
# Full NTK obtained by summing the components returned by vmap_ntk_loader.
NTK_vmap_full

tensor([[ 5.8831, -0.0398, -0.1852],
        [-0.0398,  0.5904,  0.3423],
        [-0.1852,  0.3423,  2.0716]], device='cuda:0')