In [4]:
import numpy as np
import torch
import random

import matplotlib.pyplot as plt


import torch
from torch import nn, optim
from torch import load
from torch.nn import functional as F
from torch import autograd

from torchvision import datasets

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import time

import sys
from pathlib import Path

from numba import njit

import os

import gc

In [5]:
from torchntk.explicit import explicit_ntk

torch.vmap is currently available only on nightly releases: torch version  1.11.0.dev20220127


In [6]:
experiment_base_path = './experiments/'

d1_path = os.path.join(experiment_base_path,'d1_component')
d2_path = os.path.join(experiment_base_path,'d2_component')
d3_path = os.path.join(experiment_base_path,'d3_component')
d4_path = os.path.join(experiment_base_path,'d4_component')

ntk_path = os.path.join(experiment_base_path,'ntk')

In [7]:
os.makedirs(d1_path,exist_ok=True)
os.makedirs(d2_path,exist_ok=True)
os.makedirs(d3_path,exist_ok=True)
os.makedirs(d4_path,exist_ok=True)
os.makedirs(ntk_path,exist_ok=True)

In [8]:
SEED = 0
N_EPOCHS=100
LR=1e-2

In [9]:
def activation(x):
    return torch.tanh(x)

def d_activationt(x):
    return torch.cosh(x)**-2

In [10]:
def NTK_weights(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]
    if isinstance(m, nn.Conv2d):
        print(m.weight.shape)
        nn.init.normal_(m.weight.data)#/m.weight.shape[0]
        if m.bias != None:
            nn.init.normal_(m.bias.data)#/m.weight.shape[0]

In [11]:
class FC(torch.nn.Module):
    '''
    simple network for test cases
    
    
    It seems like bias vectors aren't trivially added.
    '''
    def __init__(self,):
        super(FC, self).__init__()
        #input size=(N,784)
        self.d1 = torch.nn.Linear(784,100,bias=False)

        self.d2 = torch.nn.Linear(100,100,bias=False)
        
        self.d3 = torch.nn.Linear(100,100,bias=False)
        
        self.d4 = torch.nn.Linear(100,1,bias=False)
        
    def forward(self, x0):
        x1 = 1/np.sqrt(100) * activation(self.d1(x0))
        x2 = 1/np.sqrt(100) * activation(self.d2(x1))
        x3 = 1/np.sqrt(100) * activation(self.d3(x2))
        x4 = self.d4(x3)
        return x4, x3, x2, x1, x0

In [12]:
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
device='cuda'

model = FC()
model.to(device)
model.apply(NTK_weights)

torch.Size([100, 784])
torch.Size([100, 100])
torch.Size([100, 100])
torch.Size([1, 100])


FC(
  (d1): Linear(in_features=784, out_features=100, bias=False)
  (d2): Linear(in_features=100, out_features=100, bias=False)
  (d3): Linear(in_features=100, out_features=100, bias=False)
  (d4): Linear(in_features=100, out_features=1, bias=False)
)

In [13]:
train_data = datasets.MNIST(
    root = './DATA/',
    train = True,                          
    download = True,            
)


test_data = datasets.MNIST(
    root = './DATA/', 
    train = False, 
    download=True,
)

In [14]:
train_x = train_data.data
test_x = test_data.data

train_y = train_data.targets
test_y = test_data.targets

train_mask = torch.logical_or(train_y==6,train_y==9)
test_mask = torch.logical_or(test_y==6,test_y==9)

train_x = train_x[train_mask]
train_y = train_y[train_mask]

test_x = test_x[test_mask]
test_y = test_y[test_mask]

train_x = train_x/255.0
test_x = test_x/255.0

train_y[train_y==6] = 0
train_y[train_y==9] = 1

test_y[test_y==6] = 0
test_y[test_y==9] = 1

train_x = train_x.reshape(-1,784)
test_x = test_x.reshape(-1,784)

train_y = train_y.float()
test_y = test_y.float()

train_x = train_x - torch.mean(train_x)
train_x = train_x / torch.std(train_x)

test_x = test_x - torch.mean(train_x)
test_x = test_x / torch.std(test_x)

mask_6 = train_y == 0
mask_9 = train_y == 1

train_x_6 = train_x[mask_6]
train_x_9 = train_x[mask_9]

train_y_6 = train_y[mask_6]
train_y_9 = train_y[mask_9]

mask_6 = test_y==0
mask_9 = test_y==1

test_x_6 = test_x[mask_6]
test_x_9 = test_x[mask_9]

test_y_6 = test_y[mask_6]
test_y_9 = test_y[mask_9]

#grab 2000 from each, turn that back into train x and train y, which is now sorted.

train_y = torch.cat([train_y_6[0:5000],train_y_9[0:5000]])
train_x = torch.cat([train_x_6[0:5000],train_x_9[0:5000]])

test_x = torch.cat([test_x_6,test_x_9])
test_y = torch.cat([test_y_6,test_y_9])

train_x = train_x.to('cuda')
train_y = train_y.to('cuda')

test_x = test_x.to('cuda')
test_y = test_y.to('cuda')

In [15]:
criterion = torch.nn.BCEWithLogitsLoss()

In [9]:
def forward(model, x_test):
    x_4, x_3, x_2, x_1, x_0 = model(x_test)

    Xs = [] # Xs are shape (output x #DP) ; however, typical python notation is reversed, so we take transpose here
    Xs.append(x_0.T.detach())
    Xs.append(x_1.T.detach())
    Xs.append(x_2.T.detach())
    Xs.append(x_3.T.detach())

    #This is used to create arrays-- needs to be integer list to play nice with compilers
    d_int = []
    d_int.append(100)
    d_int.append(100)
    d_int.append(100)
    d_int.append(1)

    d_array = [] #this is for the NTK formulation, 
    #ds_array.append(torch.tensor([1.0],dtype=torch.float32).to(device)) #first element is a spacer, could be anything.

    d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device)) 
    d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device)) 
    d_array.append(torch.tensor([100.0],dtype=torch.float32).to(device))
    d_array.append(torch.tensor([1.0],dtype=torch.float32).to(device))

    layers=[model.d1,
            model.d2,
            model.d3,
            model.d4,
           ]
    
    #Ws Ks Xs d_int d_array strides padding layers d_activationt, device="cuda"
    return x_4[:,0], {'Xs':Xs, 'd_int':d_int, 'd_array':d_array, 'layers':layers, 'd_activationt':d_activationt, 'device':'cuda'}

In [17]:
def visualize_spectrum(NTK_component,epoch,save_str):
    '''
    takes a numpy array and saves the image to save_str
    '''
    fig = plt.figure(figsize=(8,8))
    plt.hist(torch.linalg.eigvalsh(NTK_component).detach().cpu().numpy(), bins=np.logspace(-14,4,100,10.0))
    med = torch.mean(NTK_component)
    plt.xscale('log')
    plt.xlabel('eigenvalue')
    plt.title('MNIST-2 {} eigenvalue spectrum, step={}'.format(save_str,str(epoch).zfill(5)))
    plt.vlines(med.detach().cpu().numpy(),0,200,color='k',linestyle='dashed')
    #plt.savefig(os.path.join(experiment_base_path,'images',save_str,str(epoch).zfill(5)+'.png'))
    plt.savefig(os.path.join('./images','spectrumMNIST2'+save_str+'_'+str(epoch).zfill(5)+'.pdf'))
    plt.close()
    #plt.show()
    
def visualize_spectrum_linear(NTK_component,epoch,save_str):
    '''
    takes a numpy array and saves the image to save_str
    '''
    fig = plt.figure(figsize=(8,8))
    plt.hist(torch.linalg.eigvalsh(NTK_component).detach().cpu().numpy(), bins=np.linspace(1e-14,1e4,100,10.0))
    plt.xlabel('eigenvalue')
    plt.title('{} eigenvalue spectrum: {}'.format(save_str,str(epoch).zfill(5)))
    plt.savefig(os.path.join(experiment_base_path,'images',save_str,str(epoch).zfill(5)+'.png'))
    plt.close()
    #plt.show()
    
def visualize_ntk_matrix(NTK_component,epoch,save_str):
    '''
    plots a heatmap of the NTK, and does so without changing the values
    '''
    fig = plt.figure(figsize=(8,8))
    plt.imshow(NTK_component.detach().cpu().numpy(),cmap='Greys')
    plt.colorbar()
    plt.xlabel('NTK index j')
    plt.ylabel('NTK index i')
    plt.title('{} matrix map: {}'.format(save_str,str(epoch).zfill(5)))
    plt.show()
    #plt.savefig(os.path.join(experiment_base_path,'images_matrixmap',save_str,str(epoch).zfill(5)+'.png'))
    #plt.close()
    
def visualize_ntk_matrix_scaled(NTK_component,epoch,save_str):
    '''
    plots a heatmap of the NTK, and does so by scaling the data first to be between -1 and 1
    '''
    NTK_component = NTK_component.detach().cpu().numpy()
    #scale between 0, 1
    vmin, vmax = np.quantile(NTK_component,[0.01,0.99])
    #
    fig = plt.figure(figsize=(8,8))
    plt.imshow(NTK_component,cmap='Greys',vmin=vmin, vmax=vmax)
    plt.colorbar()
    plt.xlabel('NTK index j')
    plt.ylabel('NTK index i')
    plt.title('MNIST-2 {} matrix map, step={}'.format(save_str,str(epoch).zfill(5)))
    
    #plt.show()
    #plt.savefig(os.path.join(experiment_base_path,'images_matrixmapscaled',save_str,str(epoch).zfill(5)+'.png'))
    plt.savefig(os.path.join('./images','matrixMNIST2_rel_'+save_str+'_'+str(epoch).zfill(5)+'.pdf'))
    plt.close()
    
def visualize_ntk_matrix_absscaled(NTK_component,epoch,save_str,vmin,vmax):
    '''
    plots a heatmap of the NTK, and does so by scaling the data first to be between -1 and 1
    '''
    NTK_component = NTK_component.detach().cpu().numpy()
    #
    fig = plt.figure(figsize=(8,8))
    plt.imshow(NTK_component,cmap='Greys',vmin=vmin, vmax=vmax)
    plt.colorbar()
    plt.xlabel('NTK index j')
    plt.ylabel('NTK index i')
    plt.title('MNIST-2 {} matrix map, step={}'.format(save_str,str(epoch).zfill(5)))
    
    #plt.savefig(os.path.join('./images','matrixMNIST2_abs_'+save_str+'_'+str(epoch).zfill(5)+'.pdf'))
    #plt.close()
    #plt.show()
    plt.savefig(os.path.join(experiment_base_path,'images_matrixmapscaled',save_str,str(epoch).zfill(5)+'.png'))
    plt.close()
    
def distributions(NTK_component,epoch,save_str):
    '''
    plots the distribution of all values in sixs and all values in nines
    '''
    NTK_component = NTK_component.detach().cpu().numpy()
    
    
    #This would be used to determine if the value is six or nine, not the distribution.
    sixes_and_sixes = NTK_component[0:5000,0:5000]
    sixes_and_nines= NTK_component[0:5000,5000::]
                            
    nines_and_sixes= NTK_component[5000::,0:5000]
    nines_and_nines = NTK_component[5000::,5000::]
    
    #Now  plot two things: first the distribution of 6s with 6s and the distribution of 6s with 9s
    fig = plt.figure(figsize=(8,8))
    bins = np.linspace(np.quantile(NTK_component[0:5000,:],0.001), np.quantile(NTK_component[0:5000,:],0.999), 300)
    plt.hist(sixes_and_sixes.flatten(),histtype='step',color='tab:blue',label='6s with 6s',bins=bins)
    plt.hist(sixes_and_nines.flatten(),histtype='step',color='tab:orange',label='6s with 9s',bins=bins)
    plt.title('MNIST-2 {} Histogram of 6s, step={}'.format(save_str,str(epoch).zfill(5)))
    plt.legend()
    plt.savefig(os.path.join('./images','distributionMNIST2_'+save_str+'_'+str(epoch).zfill(5)+'.png'))
    plt.close()
    #plt.show()
    
def accuracy_kernels(model,x,y,X,Y):
    '''
    According to a user, matrix inversion of 10k x 10k matrices is a very costly operation.
    
    And this is a complex operation since it requires calculating the Jacobian of x and having the Jacobian of X
    handy. I dont store those numbers, so if you wanted to be useful you would need to 
    '''
    index_test = len(x)
    __, params = forward(model,torch.cat([x,X]))
    kernel_components = explicit_ntk(**params)
    Y_prime = Y.clone()
    Y_prime[Y_prime==0] = -1
    y_prime = y.clone()
    y_prime[y_prime==0] = -1
    layer=1
    for i in reversed(range(len(kernel_components))):
        f_x = torch.sign(torch.sum(Y_prime*kernel_components[i][0:index_test,index_test::],dim=1))
        #should be len index_test vector
        correct = torch.sum(f_x == y_prime)
        print('layer {} kernel acc: {:.2f}'.format(i+1,100*correct/index_test))
    ntk = torch.sum(torch.stack(kernel_components),dim=0)
    f_x = torch.sign(torch.sum(Y_prime*kernel_components[i][0:index_test,index_test::],dim=1))
    correct = torch.sum(f_x == y_prime)
    print('full ntk kernel acc: {:.2f}'.format(100*correct/index_test))
    
    del kernel_components[:]
    del kernel_components
    torch.cuda.empty_cache()

In [18]:
from tqdm import tqdm

In [23]:
losses=[]
accuracies=[]

step=0
#while True:
for step in tqdm(range(1,N_EPOCHS+1)):
    now = time.time()
    
    model.zero_grad()
    
    output, params = forward(model, train_x)
    
    if step==1 or step==N_EPOCHS:
        components = explicit_ntk(**params)
        
        visualize_ntk_matrix_absscaled(components[-4],step,'layer4',-0.3367,0.4031)
        visualize_ntk_matrix_absscaled(components[-3],step,'layer3',-0.2959,0.7079)
        visualize_ntk_matrix_absscaled(components[-2],step,'layer2',-0.3990,1.8657)
        visualize_ntk_matrix_absscaled(components[-1],step,'layer1',-13.1301,421.7019)
        visualize_ntk_matrix_absscaled(torch.sum(torch.stack(components),dim=0),step,'ntk',-13.1301,421.7019)
        
        del components[:]
        del components
        torch.cuda.empty_cache()
        
    if step==1 or step==N_EPOCHS:
        accuracy_kernels(model,test_x,test_y,train_x,train_y)
    
    loss = criterion(output,train_y)
    loss.backward()
    losses.append(loss.item())
        
    with torch.no_grad():
        for name,W in model.named_parameters():
            W-= LR * W.grad
            
    predictions = torch.round(torch.sigmoid(output))
    acc = torch.sum(predictions==train_y)
    accuracies.append(acc.detach().cpu().numpy()/len(train_y))
    print(time.time() - now)
    

 75%|█████████████████████████████████████████████████████████████▌                    | 75/100 [00:14<00:03,  7.11it/s]

layer 4 kernel acc: 89.27
layer 3 kernel acc: 95.02
layer 2 kernel acc: 93.39
layer 1 kernel acc: 91.31
full ntk kernel acc: 91.31
14.860041856765747
0.0013391971588134766
0.0013425350189208984
0.001352548599243164
0.0013248920440673828
0.0013248920440673828
0.001329660415649414
0.0013113021850585938
0.0013134479522705078
0.001323699951171875
0.001316070556640625
0.001316070556640625
0.0013124942779541016
0.001310110092163086
0.0013203620910644531
0.0013186931610107422
0.0013124942779541016
0.001337289810180664
0.0013284683227539062
0.0013120174407958984
0.001314401626586914
0.0013141632080078125
0.0013194084167480469
0.0013148784637451172
0.001313924789428711
0.0013229846954345703
0.001325845718383789
0.001310110092163086
0.0013229846954345703
0.0013251304626464844
0.001317739486694336
0.0013120174407958984
0.0013213157653808594
0.0013124942779541016
0.0013134479522705078
0.001323699951171875
0.001312255859375
0.0013234615325927734
0.0013217926025390625
0.0013146400451660156
0.0013203

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00,  3.37it/s]

layer 4 kernel acc: 89.32
layer 3 kernel acc: 95.17
layer 2 kernel acc: 93.39
layer 1 kernel acc: 91.46
full ntk kernel acc: 91.46
14.679290771484375





In [22]:
plt.plot(losses)
plt.title('MNIST-2 Training Loss')
plt.ylabel('Loss Value')
plt.xlabel('Step')
plt.savefig(os.path.join(experiment_base_path,'MNIST2_losses.pdf'))
plt.close()

plt.plot(accuracies,label='train')
plt.title('MNIST-2 Training Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Step')
plt.ylim(-0.01,1.01)
plt.savefig(os.path.join(experiment_base_path,'MNIST2_accuracies.pdf'))
plt.close()

print(losses[-1])

0.6663719415664673


In [None]:
print(losses[0])
print(losses[-1])

print(accuracies[0])
print(accuracies[-1])