In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit
import os
from tqdm import tqdm
import torch

#from layerwise_ntk import compute_NTK_CNN
import numpy as np
import random
#import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

from numba import njit

In [2]:
def component_ntk(Ws: list, Ks: list, Xs: list, d_int: list, d_array: list, strides: list, padding: list, layers: list, d_activationt, device="cuda",) -> list:
    '''    
    Inputs: 
    Ws: list, has length number of layers + any reshaping, contains dense layer weight tensors, detatched, and on device
    Ks: list, has length number of layers + any reshaping, contains 2d convolutional layers weight tensors, detatched, and on device
    Xs: list, has length number of layers + any reshaping, contains all intermediate outputs of each layer in the models
    d_int: list, has the number of bias parameters in each dense layer in the models
    d_array: list, has the value of which to sqrt and divide by in each layer, typically called the NTK normalization. else, its values are 1.
    strides: list, has the value of stride in each convolutional layer, else 0
    padding: list, has the value of padding in each convolutional layer, else 0
    layers: list, is a list containing the pytorch layers in the model, and "0" as a placeholder for a reshaping layer
    d_activationt, a function containing the derivative of the activation function used in all layers but the output layer, composed of pytorch functions
    device: str, one of either 'cpu' or 'cuda'; must be the same as the model device location
    
    NOTE: all of the above lists should have the same length! See example
    
    OUTPUTS: list of torch.tensor objects, each the ntk 'component' for that layer. Given in backwards order, i.e. starting with the last layer. First weight, then bias.
    
    NOTE: to get the full ntk, simply sum over the layer dimension of the result: NTK = torch.sum(torch.stack(components),dim=(0))
    
    updated
    '''
    components = []
    
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
    n = Xs[0].shape[-1] #number of datapoints

    #holds the derivatives of activation, first value is empty list...?; just a spacer, replace with array
    Ds_dense = [np.array([[0.0]],dtype=np.float32)] 
    Ds_conv = [np.array([[0.0]],dtype=np.float32)]
    s_matrices = []
    with torch.no_grad():
        ####################################################################################################
        for l in range(0,L):
            if isinstance(layers[l],torch.nn.Linear):
                Ds_dense.append(d_activationt(layers[l](Xs[l].T)).T)
            else:
                Ds_dense.append(np.array([[0.0]],dtype=np.float32))
        ################################################################################################
        for l in range(0,L):
            if isinstance(layers[l],torch.nn.Conv2d):
                Ds_conv.append(d_activationt(layers[l](Xs[l].T)).reshape(n,-1).T)
            else:
                Ds_conv.append(np.array([[0.0]],dtype=np.float32))      
        ####################################################################################################
        S = torch.tensor([1.0],dtype=torch.float32).to(device) #this models the backward propogation:   
        for l in range(L,-1,-1):
            if isinstance(layers[l], torch.nn.Linear):
                components.append(cross_pt_nonp(S,device)*cross_pt_nonp(Xs[l],device)/d_array[l])
                if not(hasattr(layers[l], 'bias')) and not(getattr(layers[l], 'bias') is None):
                    W = torch.ones((d_int[l],n),dtype=torch.float32).to(device) * S
                    components.append(cross_pt_nonp(W,device).to(device)/d_array[l])

            elif isinstance(layers[l], torch.nn.Conv2d):
                if len(S.shape) == 2: #this should only affect the very last layer, at which point, who cares.
                    S = S[None,:,:]
                    
                dw = calc_dw(x=Xs[l].T.cpu().numpy(),w=Ks[l].cpu().numpy(),b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
                
                W = torch.matmul(torch.from_numpy(dw).to(device),S.to(device))
                
                W = torch.diagonal(W,0,0,2)

                components.append(cross_pt_nonp(W,device).to(device)/d_array[l])
                
                if not(hasattr(layers[l], 'bias')) and not(getattr(layers[l], 'bias') is None):
                    N = Ks[l].shape[0]
                    W = np.split(S.cpu().numpy(),N,axis=1)
                    
                    W = np.array(W)
                    
                    W = np.sum(W,axis=(1,2))
                    
                    components.append(torch.from_numpy(cross(W,)).to(device)/d_array[l])

            #############################
            #now we setup S for the next loop by treating appropriately
            if l==0:
                break

            if isinstance(layers[l], torch.nn.Linear):
                S = torch.matmul(S.T,Ws[l]).T / torch.sqrt(d_array[l])
                if len(S.shape) < 2:
                    S = S[:,None] #expand dimension along axis 1
                if not(isinstance(layers[l-1],float)): #this exludes the reshaping layer
                    S = Ds_dense[l]*S
                else: #and when the reshaping layer occurs we need to apply this instead
                    S = Ds_conv[l-1]*S

            elif isinstance(layers[l], torch.nn.Conv2d):
                dx = calc_dx(x=Xs[l].T.cpu().numpy(),w=Ks[l].cpu().numpy(),b=0,pad=padding[l],stride=strides[l],H_=Xs[l+1].shape[1],W_=Xs[l+1].shape[0])
                S = (torch.from_numpy(dx[None,:,:]).to(device) @ S) / torch.sqrt(d_array[l])
                S = Ds_conv[l]*S
            
    return components

In [3]:
def activation(x):
    return torch.tanh(x)

def d_activationt(x):
    return torch.cosh(x)**-2

In [4]:
class model(torch.nn.Module):
    def __init__(self,):
        super(model, self).__init__()
        
        self.d1 = torch.nn.Linear(10,10,bias=False)

        self.d2 = torch.nn.Linear(10,10,bias=False)
        
        self.d3 = torch.nn.Linear(10,10,bias=False)
        
        self.d4 = torch.nn.Linear(10,1,bias=False)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        return x_4, x_3, x_2, x_1, x_0

In [5]:
def NTK_weights(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]
    if isinstance(m, nn.Conv2d):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]

In [180]:
SEED=0
how_many=3

import random

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cuda'

mymodel = model()
mymodel.apply(NTK_weights)
mymodel.to(device)

x_test = np.random.normal(0,1,(how_many,10)).astype(np.float32) #n c_in, h, w
x_test = torch.from_numpy(x_test)


torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])


In [181]:
hasattr(mymodel.d1,'bias') and not(getattr(mymodel.d1, 'bias') is None)

False

# Calculate the Fisher Matrix

In [182]:
def cross_pt_nonp(X,device='cuda'):
    X = X.to(device)
    return X.T.matmul(X)

def notthatsimple(X,device='cuda'):
    X = X.to(device)
    return X.matmul(X.T)

In [207]:
def component_fisher(Ws: list, Ks: list, Xs: list, d_int: list, d_array: list, strides: list, padding: list, layers: list, d_activationt, device="cuda",) -> list:
    '''    
    Inputs: 
    Ws: list, has length number of layers + any reshaping, contains dense layer weight tensors, detatched, and on device
    Ks: list, has length number of layers + any reshaping, contains 2d convolutional layers weight tensors, detatched, and on device
    Xs: list, has length number of layers + any reshaping, contains all intermediate outputs of each layer in the models
    d_int: list, has the number of bias parameters in each dense layer in the models
    d_array: list, has the value of which to sqrt and divide by in each layer, typically called the NTK normalization. else, its values are 1.
    strides: list, has the value of stride in each convolutional layer, else 0
    padding: list, has the value of padding in each convolutional layer, else 0
    layers: list, is a list containing the pytorch layers in the model, and "0" as a placeholder for a reshaping layer
    d_activationt, a function containing the derivative of the activation function used in all layers but the output layer, composed of pytorch functions
    device: str, one of either 'cpu' or 'cuda'; must be the same as the model device location
    
    NOTE: all of the above lists should have the same length! See example
    
    OUTPUTS: list of torch.tensor objects, each the ntk 'component' for that layer. Given in backwards order, i.e. starting with the last layer. First weight, then bias.
    
    NOTE: to get the full ntk, simply sum over the layer dimension of the result: NTK = torch.sum(torch.stack(components),dim=(0))
    
    updated
    '''
    components = []
    
    L = len(Xs)-1 #number of layers, Xs goes from inputs to right before outputs; X_0 is the input, X_L CK
    n = Xs[0].shape[-1] #number of datapoints

    #holds the derivatives of activation, first value is empty list...?; just a spacer, replace with array
    Ds_dense = [np.array([[0.0]],dtype=np.float32)] 
    s_matrices = []
    with torch.no_grad():
        ####################################################################################################
        for l in range(0,L):
            if isinstance(layers[l],torch.nn.Linear):
                Ds_dense.append(d_activationt(layers[l](Xs[l].T)).T)
            else:
                Ds_dense.append(np.array([[0.0]],dtype=np.float32))
        ####################################################################################################
        S = torch.tensor([1.0],dtype=torch.float32).to(device) #this models the backward propogation:   
        for l in range(L,-1,-1):
            if isinstance(layers[l], torch.nn.Linear):
                print('S: ',S.shape)
                print('Xs: ',Xs[l].shape)
                components.append(S*Xs[l]/d_array[l])
                print(components[-1].shape)
                if (hasattr(layers[l], 'bias')) and not(getattr(layers[l], 'bias') is None):
                    W = torch.ones((d_int[l],n),dtype=torch.float32).to(device) * S
                    components.append(W.to(device)/d_array[l])
            
            #############################
            #now we setup S for the next loop by treating appropriately
            if l==0:
                break

            if isinstance(layers[l], torch.nn.Linear):
                print("W: ",Ws[l].shape)
                S = torch.matmul(S,Ws[l]).T / torch.sqrt(d_array[l])
                if len(S.shape) < 2:
                    S = S[:,None] #expand dimension along axis 1
                if not(isinstance(layers[l-1],float)): #this exludes the reshaping layer
                    print("D: ",Ds_dense[l].shape)
                    S = Ds_dense[l]*S
                else: #and when the reshaping layer occurs we need to apply this instead
                    S = Ds_conv[l-1]*S
            
    return components

In [215]:
Xs[-2].shape

torch.Size([10, 3])

In [214]:
D = d_activationt(mymodel.d3(Xs[-2].T))
D.shape

torch.Size([3, 10])

In [220]:
Ws[-1].shape #first index is the output index

torch.Size([1, 10])

In [None]:
def dw_dw(X):
    

In [221]:
torch.matmul((Ws[-1] * D), (Xs[-2])).shape # N x P = 3 x 100

torch.Size([3, 3])

In [208]:
x_test = x_test.to('cuda')
x_4, x_3, x_2, x_1, x_0 = mymodel(x_test)

#These need to be numpy
Ws = []
Ws.append(mymodel.d1.weight.detach())
Ws.append(mymodel.d2.weight.detach())
Ws.append(mymodel.d3.weight.detach())
Ws.append(mymodel.d4.weight.detach())

#Kernel Matrices, Need to be numpy
Ks = []
Ks.append(mymodel.d1.weight.detach())
Ks.append(mymodel.d2.weight.detach())
Ks.append(mymodel.d3.weight.detach())
Ks.append(mymodel.d4.weight.detach())


Xs = [] # Xs are shape (output x #DP) ; however, typical python notation is reversed, so we take transpose here
Xs.append(x_0.T.detach())
Xs.append(x_1.T.detach())
Xs.append(x_2.T.detach())
Xs.append(x_3.T.detach())
#Xs.append(x_4.T.detach())

#This is used to create arrays-- needs to be integer list to play nice with compilers
ds_int = []
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width
ds_int.append(1) #channels_out * kernel_height * kernel_width

ds_array = [] #this is for the NTK formulation, 
#ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.

ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #The rest, even if you dont use NTK formulation, would be 1
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))
ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))

filters = []
filters.append(0)
filters.append(0)
filters.append(0)
filters.append(0)
filters.append(0)
filters.append(0)


padding = []
padding.append(0)
padding.append(0)
padding.append(0)
padding.append(0)
padding.append(0)
padding.append(0)


strides = []
strides.append(0)
strides.append(0)
strides.append(0)
strides.append(0)
strides.append(0)
strides.append(0)


layers=[mymodel.d1,
        mymodel.d2,
        mymodel.d3,
        mymodel.d4,
       ]

In [209]:
fc = component_fisher(Ws, Ks, Xs, ds_int, ds_array, strides, padding, layers, d_activationt)

S:  torch.Size([1])
Xs:  torch.Size([10, 3])
torch.Size([10, 3])
W:  torch.Size([1, 10])
D:  torch.Size([10, 3])
S:  torch.Size([10, 3])
Xs:  torch.Size([10, 3])
torch.Size([10, 3])
W:  torch.Size([10, 10])
D:  torch.Size([10, 3])


RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 1

In [190]:
#earliest layers come last
I_fc_00 = fc[-1]@fc[-1].T
I_fc_01 = fc[-1]@fc[-2].T
I_fc_02 = fc[-1]@fc[-3].T
I_fc_03 = fc[-1]@fc[-4].T

I_fc_10 = fc[-2]@fc[-1].T
I_fc_11 = fc[-2]@fc[-2].T
I_fc_12 = fc[-2]@fc[-3].T
I_fc_13 = fc[-2]@fc[-4].T

I_fc_20 = fc[-3]@fc[-1].T
I_fc_21 = fc[-3]@fc[-2].T
I_fc_22 = fc[-3]@fc[-3].T
I_fc_23 = fc[-3]@fc[-4].T

I_fc_30 = fc[-4]@fc[-1].T
I_fc_31 = fc[-4]@fc[-2].T
I_fc_32 = fc[-4]@fc[-3].T
I_fc_33 = fc[-4]@fc[-4].T

Ifc_pieces_0 = torch.cat([I_fc_00,I_fc_01,I_fc_02,I_fc_03],axis=1) #100, 341
Ifc_pieces_1 = torch.cat([I_fc_10,I_fc_11,I_fc_12,I_fc_13],axis=1) #100, 341
Ifc_pieces_2 = torch.cat([I_fc_20,I_fc_21,I_fc_22,I_fc_23],axis=1) #100, 341
Ifc_pieces_3 = torch.cat([I_fc_30,I_fc_31,I_fc_32,I_fc_33],axis=1) #100, 341

Ifc_pieces = torch.cat([Ifc_pieces_0,Ifc_pieces_1,Ifc_pieces_2,Ifc_pieces_3],axis=0) #341, 341



In [191]:
Ifc_pieces.shape

torch.Size([40, 40])

# Develop Autograd because we need to check if these preliminary matrices are correct anyways

In [257]:
y = mymodel(x_test)[0]

In [258]:
y.shape

torch.Size([3, 1])

In [269]:
#in the future we would iterate over layers instead of like this...
layer_components_w1 = [] 
layer_components_w2 = []
layer_components_w3 = []
layer_components_w4 = []

layer_components_b1 = []
layer_components_b2 = []
layer_components_b3 = []
layer_components_b4 = []

for i in range(len(y)):#range(len(y)):
    mymodel.zero_grad()
    y[i,0].backward(retain_graph=True)
    #Get the tensors

    #reshape and append. deep copy neccessary or else they are the same objects
    layer_components_w1.append(mymodel.d1.weight.grad.detach().reshape(-1).clone())
    layer_components_w2.append(mymodel.d2.weight.grad.detach().reshape(-1).clone())
    layer_components_w3.append(mymodel.d3.weight.grad.detach().reshape(-1).clone())
    layer_components_w4.append(mymodel.d4.weight.grad.detach().reshape(-1).clone())
    
J_w1 = torch.stack(layer_components_w1).T
J_w2 = torch.stack(layer_components_w2).T 
J_w3 = torch.stack(layer_components_w3).T
J_w4 = torch.stack(layer_components_w4).T 

J = torch.cat([J_w1,J_w2,J_w3,J_w4],dim=0)

In [270]:
NTK = (J.T@J) #N x N matrix
NTK.shape

torch.Size([3, 3])

In [271]:
torch.linalg.eigvalsh(NTK)

tensor([1.7999e-10, 4.1165e-03, 7.8675e+00], device='cuda:0')

In [272]:
I = (J@J.T) # P x P matrix

In [273]:
torch.linalg.eigvalsh(I)[-3::] #these are, to machine precision.... which I guess is like 10e-6;
#so I guess that has implications for out NTK eigenvalue anaysis

tensor([3.5996e-07, 4.1161e-03, 7.8675e+00], device='cuda:0')

In [274]:
J_w1.shape

torch.Size([100, 3])

In [147]:
I_00 = J_w1@J_w1.T
I_01 = J_w1@J_w1.T
I_02 = J_w1@J_w1.T
I_03 = J_w1@J_w1.T

I_10 = J_w1@J_w1.T
I_11 = J_w1@J_w1.T
I_12 = J_w1@J_w1.T
I_13 = J_w1@J_w1.T

I_20 = J_w1@J_w1.T
I_21 = J_w1@J_w1.T
I_22 = J_w1@J_w1.T
I_23 = J_w1@J_w1.T

I_30 = J_w1@J_w1.T
I_31 = J_w1@J_w1.T
I_32 = J_w1@J_w1.T
I_33 = J_w1@J_w1.T

Ifc_pieces_0 = torch.cat([I_fc_00,I_fc_01,I_fc_02,I_fc_03],axis=1) #100, 341
Ifc_pieces_1 = torch.cat([I_fc_10,I_fc_11,I_fc_12,I_fc_13],axis=1) #100, 341
Ifc_pieces_2 = torch.cat([I_fc_20,I_fc_21,I_fc_22,I_fc_23],axis=1) #100, 341
Ifc_pieces_3 = torch.cat([I_fc_30,I_fc_31,I_fc_32,I_fc_33],axis=1) #100, 341

Ifc_pieces = torch.cat([Ifc_pieces_0,Ifc_pieces_1,Ifc_pieces_2,Ifc_pieces_3],axis=0) #341, 341




In [148]:
I == I_pieces

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]], device='cuda:0')

In [None]:
I = J[0:100,0:100]

In [130]:
print(I.detach().cpu().numpy())

[[ 1.2020423e-03  5.4884329e-04  7.9924171e-04 ...  4.5533446e-03
  -2.4962872e-03 -5.6376527e-03]
 [ 5.4884329e-04  2.9127533e-03  1.6404245e-03 ...  4.2535797e-02
  -2.1792766e-02 -5.3483039e-02]
 [ 7.9924171e-04  1.6404245e-03  1.1425368e-03 ...  2.2411257e-02
  -1.1555082e-02 -2.8140074e-02]
 ...
 [ 4.5533446e-03  4.2535797e-02  2.2411257e-02 ...  6.3206965e-01
  -3.2331926e-01 -7.9501843e-01]
 [-2.4962872e-03 -2.1792766e-02 -1.1555082e-02 ... -3.2331926e-01
   1.6540968e-01  4.0665877e-01]
 [-5.6376527e-03 -5.3483039e-02 -2.8140074e-02 ... -7.9501843e-01
   4.0665877e-01  9.9998254e-01]]


In [53]:
layer_components_w1.shape #[P x N]

torch.Size([100, 3])

In [54]:
X = layer_components_w1 @ layer_components_w1.T  #this should be left in the torch tensor.

In [131]:
X

torch.Size([100, 100])

In [64]:
fisher_components[-5].shape

torch.Size([10, 10])

# what about improvements to torch.autograd.jacobian

In [92]:
y2 = mymodel(x_test)

In [97]:
y2[-1].shape

torch.Size([3, 10])

In [167]:
J1 = torch.autograd.functional.jacobian(mymodel,y2[-1])

In [177]:
J1[4]

tensor([[[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],


# Layerwise Autograd Method for arbitary pytorch models

In [275]:
class model2(torch.nn.Module):
    def __init__(self,):
        super(model2, self).__init__()
        
        self.d1 = torch.nn.Linear(10,10,bias=False)

        self.d2 = torch.nn.Linear(10,10,bias=False)
        
        self.d3 = torch.nn.Linear(10,10,bias=False)
        
        self.d4 = torch.nn.Linear(10,1,bias=False)
        
    def forward(self, x_0):
        x_1 = activation(self.d1(x_0))
        x_2 = activation(self.d2(x_1))
        x_3 = activation(self.d3(x_2))
        x_4 = activation(self.d4(x_3))
        return x_4

In [302]:
mymodel2 = model2()
mymodel2.apply(NTK_weights)
mymodel2.to(device)

x_test2 = np.random.normal(0,1,(how_many,10)).astype(np.float32) #n c_in, h, w
x_test2 = torch.from_numpy(x_test2)
x_test2 = x_test2.to(device)

torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])


In [365]:
y = mymodel2(x_test2)[:,0] #defined to only work for networks that terminate to single neuron, so this is okay

In [6]:
def autograd_components_NTK(model: torch.nn.Module, y :torch.tensor):
    
    NTKs = {}
    
    if len(y.shape) > 1:
        raise ValueError('y must be 1-D, but its shape is: {}'.format(y.shape))
    
    params_that_need_grad = []
    for param in model.parameters():
        if param.requires_grad:
            params_that_need_grad.append(param.requires_grad)
            #first set all gradients to not calculate, time saver
            param.requires_grad = False
        else:
            params_that_need_grad.append(param.requires_grad)

    #how do we parallelize this operation across multiple gpus or something? that be sweet.
    for i,z in enumerate(model.named_parameters()):
        if not(params_that_need_grad[i]): #if it didnt need a grad, we can skip it.
            continue
        name, param = z
        param.requires_grad = True #we only care about this tensors gradients in the loop
        this_grad=[]
        for i in range(len(y)): #first dimension must be the batch dimension
            model.zero_grad()
            y[i].backward(create_graph=True)
            this_grad.append(param.grad.detach().reshape(-1).clone())

        J_layer = torch.stack(this_grad) # [N x P matrix] #this will go against our notation, but I'm not adding

        NTKs[name] = J_layer @ J_layer.T # An extra transpose operation to my code for us to feel better

        param.requires_grad = False
     
    #reset the model object to be how we started this function
    for i,param in enumerate(model.parameters()):
        if params_that_need_grad[i]:
            param.requires_grad = True #

    return NTKs

def autograd_components_Fisher(model: torch.nn.Module, y :torch.tensor):
    
    fishers = {}
    
    if len(y.shape) > 1:
        raise ValueError('y must be 1-D, but its shape is: {}'.format(y.shape))
    
    params_that_need_grad = []
    for param in model.parameters():
        if param.requires_grad:
            params_that_need_grad.append(param.requires_grad)
            #first set all gradients to not calculate, time saver
            param.requires_grad = False
        else:
            params_that_need_grad.append(param.requires_grad)

    #how do we parallelize this operation across multiple gpus or something? that be sweet.
    for i,z in enumerate(model.named_parameters()):
        if not(params_that_need_grad[i]): #if it didnt need a grad, we can skip it.
            continue
        name, param = z
        param.requires_grad = True #we only care about this tensors gradients in the loop
        this_grad=[]
        for i in range(len(y)): #first dimension must be the batch dimension
            model.zero_grad()
            y[i].backward(create_graph=True)
            this_grad.append(param.grad.detach().reshape(-1).clone())

        J_layer = torch.stack(this_grad) # [N x P matrix] #this will go against our notation, but I'm not adding

        fishers[name] = J_layer.T @ J_layer # An extra transpose operation to my code for us to feel better

        param.requires_grad = False
     
    #reset the model object to be how we started this function
    for i,param in enumerate(model.parameters()):
        if params_that_need_grad[i]:
            param.requires_grad = True #

    return fishers 

def reconstruct_full_fisher_from_components(fisher):
    raise ValueError('Not Implemented Yet')
    #given the dictionary of fisher information matrices, we can 
    #combine them to get the full fisher information matrix.
    return

def autograd_components_Jacobian(model: torch.nn.Module, y :torch.tensor):
    
    Jacobians = {}
    
    if len(y.shape) > 1:
        raise ValueError('y must be 1-D, but its shape is: {}'.format(y.shape))
    
    params_that_need_grad = []
    for param in model.parameters():
        if param.requires_grad:
            params_that_need_grad.append(param.requires_grad)
            #first set all gradients to not calculate, time saver
            param.requires_grad = False
        else:
            params_that_need_grad.append(param.requires_grad)

    #how do we parallelize this operation across multiple gpus or something? that be sweet.
    for i,z in enumerate(model.named_parameters()):
        if not(params_that_need_grad[i]): #if it didnt need a grad, we can skip it.
            continue
        name, param = z
        param.requires_grad = True #we only care about this tensors gradients in the loop
        this_grad=[]
        for i in range(len(y)): #first dimension must be the batch dimension
            model.zero_grad()
            y[i].backward(create_graph=True)
            this_grad.append(param.grad.detach().reshape(-1).clone())

        J_layer = torch.stack(this_grad) # [N x P matrix] #this will go against our notation, but I'm not adding

        Jacobians[name] = J_layer # An extra transpose operation to my code for us to feel better

        param.requires_grad = False
     
    #reset the model object to be how we started this function
    for i,param in enumerate(model.parameters()):
        if params_that_need_grad[i]:
            param.requires_grad = True #

    return Jacobians


        
def autograd_NTK(model: torch.nn.Module, y :torch.tensor):
    
    NTK = False
    
    if len(y.shape) > 1:
        raise ValueError('y must be 1-D, but its shape is: {}'.format(y.shape))
    
    params_that_need_grad = []
    for param in model.parameters():
        if param.requires_grad:
            params_that_need_grad.append(param.requires_grad)
            #first set all gradients to not calculate, time saver
            param.requires_grad = False
        else:
            params_that_need_grad.append(param.requires_grad)

    #how do we parallelize this operation across multiple gpus or something? that be sweet.
    for i,z in enumerate(model.named_parameters()):
        if not(params_that_need_grad[i]): #if it didnt need a grad, we can skip it.
            continue
        name, param = z
        param.requires_grad = True #we only care about this tensors gradients in the loop
        this_grad=[]
        for i in range(len(y)): #first dimension must be the batch dimension
            model.zero_grad()
            y[i].backward(create_graph=True)
            this_grad.append(param.grad.detach().reshape(-1).clone())

        J_layer = torch.stack(this_grad) # [N x P matrix] #this will go against our notation, but I'm not adding

        if (type(NTK) is bool) and not(NTK):
            NTK = J_layer @ J_layer.T # An extra transpose operation to my code for us to feel better
        else:
            NTK += J_layer @ J_layer.T

        param.requires_grad = False
     
    #reset the model object to be how we started this function
    for i,param in enumerate(model.parameters()):
        if params_that_need_grad[i]:
            param.requires_grad = True #

    return NTK

# What about something like this?

In [12]:
sum(dict((p.data_ptr(), p.numel()) for p in vgg11.parameters()).values())

2225153

In [7]:
import torchvision.models as models
vgg11 = models.mobilenet_v2(num_classes=1)
vgg11.to('cuda')

MobileNetV2(
  (features): Sequential(
    (0): ConvBNActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momen

In [8]:
x_train3 = torch.empty(100,3,32,32).normal_(0,1).to('cuda')

In [9]:
y = vgg11(x_train3)[:,0]

In [10]:
now=time.time()
NTK = autograd_NTK(vgg11, y)
print(time.time()-now)

501.2134418487549


In [11]:
NTK.shape

torch.Size([100, 100])

# $$I = ( \nabla f_\theta(x)) ( \nabla f_\theta(x)^\top) $$