In [None]:
import math
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from numpy import linalg as LA
from torch import nn
from torch import optim
from torch.nn import functional as F
warnings.filterwarnings("ignore")

In [None]:
def generate_X_all(data_dim = 2000, training_data_num = 12000, test_data_num = 2000):
    """Sample random points on the unit sphere and split them into train/test sets.

    Each row is drawn from an isotropic standard Gaussian in ``data_dim``
    dimensions and then divided by its Euclidean norm, so every returned row
    has unit L2 norm.

    Args:
        data_dim: dimension of each sample.
        training_data_num: number of rows in the training split.
        test_data_num: number of rows in the test split.

    Returns:
        ``(X_all_train, X_all_test)`` with shapes
        ``(training_data_num, data_dim)`` and ``(test_data_num, data_dim)``.
        Both are slices of one underlying array.
    """
    total_num = training_data_num + test_data_num

    # generate raw data
    # The covariance is the identity, so the coordinates are i.i.d. N(0, 1).
    # Sampling them directly avoids building and factorizing a
    # data_dim x data_dim covariance matrix.
    X_all_raw = np.random.standard_normal(size=(total_num, data_dim))

    # normalize to a sphere: divide every row by its own L2 norm in one shot
    row_norms = np.linalg.norm(X_all_raw, ord=2, axis=1, keepdims=True)
    X_all_sphere = X_all_raw / row_norms

    X_all_train = X_all_sphere[:training_data_num, :]
    X_all_test = X_all_sphere[training_data_num:, :]

    return X_all_train, X_all_test

In [None]:
# ---- experiment configuration ----
data_dim = 50
training_data_num = 2000
test_data_num = 1000
N = 8000 # number of points in the auxiliary set Q generated at the end of this cell

create_new_data = True # Whether to create new data or load existing data
save_data = True # Whether to save the data created
label_noise = True #False # Whether to add label noise
mu = 0
sigma = np.sqrt(1) # mean and standard deviation of the label noise if added

if create_new_data:
    X_all_train, X_all_test = generate_X_all(data_dim, training_data_num, test_data_num)
    X_all_train_np = X_all_train

    # linear teacher: labels are X @ A for a random Gaussian vector A
    y_dim = 1
    A = np.random.normal(0, 1, size=(data_dim, y_dim))
    Y_all_train = np.dot(X_all_train, A)
    Y_all_test = np.dot(X_all_test, A)

    # noise is added to the training labels only
    if label_noise:
        noise = np.random.normal(mu, sigma, (training_data_num, y_dim))
        Y_all_train = Y_all_train + noise

    if save_data:
        np.save('X_all_train.npy', X_all_train_np)
        np.save('X_all_test.npy', X_all_test)
        np.save('Y_all_train.npy', Y_all_train)
        np.save('Y_all_test.npy', Y_all_test)
else:
    X_all_train_np = np.load('X_all_train.npy')
    # Later cells read `X_all_train`; without this binding the load path
    # leaves it undefined and the tensor-conversion cell raises NameError.
    X_all_train = X_all_train_np
    X_all_test = np.load('X_all_test.npy')
    Y_all_train = np.load('Y_all_train.npy')
    Y_all_test = np.load('Y_all_test.npy')

print(X_all_train_np.shape)
print(X_all_test.shape)
print(Y_all_train.shape)
print(Y_all_test.shape)

#generate Q
# One call yields exactly N rows for any N. Two calls of rint(N/2) rows each
# produce N +/- 1 rows when N is odd.
Q, _ = generate_X_all(data_dim, N, 0)
#print('Q shape is ' + str(Q.shape))

In [None]:
# convert to torch tensor (float32)
# torch.as_tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor. It accepts NumPy arrays as well as tensors, so re-running this
# cell after the names already hold tensors is harmless.
X_all_train = torch.as_tensor(X_all_train, dtype=torch.float32)
X_all_test = torch.as_tensor(X_all_test, dtype=torch.float32)
Y_all_train = torch.as_tensor(Y_all_train, dtype=torch.float32)
Y_all_test = torch.as_tensor(Y_all_test, dtype=torch.float32)
Q = torch.as_tensor(Q, dtype=torch.float32)

In [None]:
# Define the network class
class precond_net(torch.nn.Module):
    """Two-layer ReLU network with a trainable first layer and frozen +/-1 output weights.

    The forward pass computes ``relu(x @ W.T) @ a / sqrt(num_neurons)`` where
    ``W`` is the bias-free first layer (``dense_w``) and ``a`` (``weight_a``)
    is a fixed column of random signs.
    """

    def __init__(self, dim_in, num_neurons):
        """Build the network.

        Args:
            dim_in: input feature dimension.
            num_neurons: width of the hidden layer.
        """
        super().__init__()
        self.num_neurons = num_neurons

        # first layer (trainable, no bias) and its activation
        self.dense_w = torch.nn.Linear(dim_in, num_neurons, bias=False)
        self.relu = torch.nn.ReLU()

        # second layer with fixed weights: map random bits {0, 1} to signs {-1, +1}
        random_signs = 2 * np.random.randint(2, size=(num_neurons, 1)) - 1
        self.weight_a = torch.nn.Parameter(
            torch.from_numpy(random_signs).to(torch.float),
            requires_grad=False,
        )

        # initialize first layer
        self._init_linear()

    def _init_linear(self):
        """Re-draw every Linear weight from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        linear_layers = [layer for layer in self.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            bound = 1. / math.sqrt(layer.weight.size(1))
            with torch.no_grad():
                layer.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Return the network output for a batch ``x`` of shape (batch, dim_in)."""
        hidden = self.relu(self.dense_w(x))
        out = torch.matmul(hidden, self.weight_a)
        return out / math.sqrt(self.num_neurons)

In [None]:

# Function to compute the NTK
def compute_NTK(test_data,train_data):
    """Evaluate the closed-form kernel ``c/2 - c*arccos(c)/(2*pi)`` on pairwise inner products.

    ``c`` is the matrix of inner products between rows of ``test_data`` and
    rows of ``train_data``, clamped to [-1, 1] before ``arccos`` is applied.
    NOTE(review): the clamp suggests the rows are expected to be unit-norm so
    that ``c`` is a cosine similarity; confirm against the callers.

    Args:
        test_data: array whose last two axes are (num_test, dim).
        train_data: array whose last two axes are (num_train, dim).

    Returns:
        Kernel matrix with last two axes (num_test, num_train).
    """
    inner = np.matmul(test_data, np.swapaxes(train_data, -1, -2))
    # guard arccos against values that drift just outside [-1, 1]
    cos_sim = np.clip(inner, -1.0, 1.0)
    angle = np.arccos(cos_sim)
    return 0.5 * cos_sim - (1.0 / (2 * np.pi)) * (cos_sim * angle)

# Function to compute H matrix used in back-propagation
def get_H(x, w, a):
    """Build the per-sample weight-gradient tensor used by the manual training loops.

    For hidden unit ``j`` and sample ``i`` the result satisfies

        H[j, :, i] = a[j] * 1[w[j] . x[i] >= 0] * x[i] / (n * sqrt(m))

    Args:
        x: inputs, n x d.
        w: first-layer weights, m x d.
        a: output-layer weights, m x 1.

    Returns:
        Tensor of shape m x d x n. It is a permuted view and is not guaranteed
        to be contiguous, so flatten it with ``reshape`` rather than ``view``.
    """
    m = w.shape[0]
    n, d = x.shape

    x_t = x.permute(1, 0) # d x n

    # ReLU gate: which hidden units are active for which sample
    gate = torch.matmul(w, x_t) >= 0 # m x n (bool)

    # Fold the output sign and the 1/(n*sqrt(m)) scale into the small m x n
    # factor, so only one d x m x n tensor is allocated below instead of an
    # unscaled one plus a scaled copy.
    coeff = torch.mul(gate, a) / (n*np.sqrt(m)) # m x n

    # Operand order and shapes match the original product, so the memory layout
    # of the result, and hence how callers may flatten it, stays the same.
    H = torch.mul(x_t.unsqueeze(1).expand(d, m, n), coeff) # d x m x n
    return H.permute(1, 0, 2) # m x d x n

In [None]:
# Function to train the baseline model
def train_base(model, X_train, Y_train, X_test, Y_test, hidden_unit, learning_rate = 0.1, epochs = 1000, print_freq=100, ending_threshold = 0.0001):
    """Train the first layer of ``model`` with plain full-batch gradient descent.

    The weight update is assembled by hand from ``get_H`` rather than through
    autograd: ``dw = mean(H @ (model(X_train) - Y_train), dim=-1)``.

    Args:
        model: network exposing ``dense_w.weight`` and ``weight_a``.
        X_train, Y_train: training inputs and targets.
        X_test, Y_test: held-out inputs and targets, used for logging only.
        hidden_unit: unused; kept so existing call sites keep working.
        learning_rate: step size of the weight update.
        epochs: maximum number of updates.
        print_freq: print the losses every ``print_freq`` epochs.
        ending_threshold: stop training when the training loss is at or below
            this value.

    Returns:
        ``(model, train_loss_log, test_loss_log, train_loss_final, test_loss_final)``.
        If the threshold is reached, the final losses are those of the stopping
        epoch, which is not appended to the logs. Otherwise they are the minima
        of the respective logs, or 0 when ``epochs`` is 0.
    """
    criterion = nn.MSELoss()
    train_loss_log = []
    test_loss_log = []

    train_loss_final = 0
    test_loss_final = 0
    reached_threshold = False

    # Autograd is never used here, so run everything without building a graph.
    with torch.no_grad():
        for epoch in range(epochs):
            H = get_H(X_train, model.dense_w.weight.data, model.weight_a.data)

            train_out = model(X_train)
            dy = train_out - Y_train
            dw = torch.matmul(H, dy)    # (m x d x n) @ (n x k) -> m x d x k
            dw = torch.mean(dw, dim=-1) # m x d

            model.dense_w.weight.data = model.dense_w.weight.data-learning_rate * dw

            # training loss uses the pre-update outputs; test loss is evaluated
            # with the freshly updated weights
            train_loss = criterion(train_out, Y_train).item()
            test_loss = criterion(model(X_test), Y_test).item()

            if train_loss <= ending_threshold: # stop training if the loss is small enough
                print(f' Epoch[{epoch + 1}] Training Loss: {train_loss:.8} Test Loss: {test_loss:.8}')
                train_loss_final = train_loss
                test_loss_final = test_loss
                reached_threshold = True
                break

            train_loss_log.append(train_loss)
            test_loss_log.append(test_loss)
            if (epoch+1) % print_freq == 0:
                print(f' Epoch[{epoch + 1}] Training Loss: {train_loss:.8} Test Loss: {test_loss:.8}')

    # Without early stopping, report the best logged losses. This is computed
    # once here instead of re-scanning the whole log every epoch.
    if not reached_threshold and train_loss_log:
        train_loss_final = np.min(train_loss_log)
        test_loss_final = np.min(test_loss_log)
    return model, train_loss_log, test_loss_log, train_loss_final, test_loss_final

In [None]:
# function to train the kernel preconditioned model
def train_kernel_preconditioned(H_Q, model, X_train, Y_train, X_test, Y_test, hidden_unit, learning_rate = 0.1, epochs = 1000, print_freq=100, ending_threshold = 0.0001):
    '''
    Train the first layer of `model` with full-batch gradient descent whose update
    is preconditioned by `H_Q`:

        dw = N * H_Q @ (H_Q^T @ (H @ (model(X_train) - Y_train)))

    where H is `get_H` on the training data flattened to (m*d) x n and
    N = H_Q.shape[1].

     arguments
            H_Q: 2-D tensor of shape (m*d) x N. In this notebook it is `get_H`
                 evaluated on the auxiliary points Q at the initial weights and
                 flattened over its first two axes.
            model: network exposing `dense_w.weight` and `weight_a`.
            X_train, Y_train: training inputs and targets.
            X_test, Y_test: held-out inputs and targets, used for logging only.
            hidden_unit: unused; the hidden width is read from the shape of H.
                 Kept so existing call sites keep working.
            learning_rate: step size of the weight update.
            epochs: maximum number of updates.
            print_freq: print the losses every `print_freq` epochs.
            ending_threshold: stop training when the training loss is below this threshold.

     returns
            (model, train_loss_log, test_loss_log, train_loss_final, test_loss_final).
            If the threshold is reached, the final losses are those of the stopping
            epoch (not appended to the logs); otherwise they are the minima of the
            respective logs (0 when `epochs` is 0).
    '''
    criterion = nn.MSELoss()
    train_loss_log = []
    test_loss_log = []

    train_loss_final = 0
    test_loss_final = 0
    reached_threshold = False

    # Read the number of Q points from H_Q itself rather than from a notebook global.
    num_q = H_Q.shape[1]

    # Autograd is never used here, so run everything without building a graph.
    with torch.no_grad():
        for epoch in range(epochs):
            # ------------------------- Preconditioning -------------------------
            H = get_H(X_train, model.dense_w.weight.data, model.weight_a.data)
            neuron_num, dim, data_num = H.shape
            # reshape (not view): H is a permuted tensor, and view raises when
            # the strides do not allow merging the first two axes
            H = H.reshape(neuron_num*dim, data_num)

            train_out = model(X_train)
            dy = train_out - Y_train
            Hy = torch.matmul(H, dy)                   # (m*d) x 1
            # Scale the small N x 1 vector instead of materializing a scaled
            # copy of the whole of H_Q every epoch.
            Hy1 = torch.matmul(H_Q.permute(1,0), Hy) * num_q
            dw = torch.matmul(H_Q, Hy1)                # (m*d) x 1

            dw = dw.reshape(neuron_num, dim)
            model.dense_w.weight.data = model.dense_w.weight.data-learning_rate * dw

            # training loss uses the pre-update outputs; test loss is evaluated
            # with the freshly updated weights
            train_loss = criterion(train_out, Y_train).item()
            test_loss = criterion(model(X_test), Y_test).item()

            if train_loss <= ending_threshold: # stop training if the loss is small enough
                print(f' Epoch[{epoch + 1}] Training Loss: {train_loss:.8} Test Loss: {test_loss:.8}')
                train_loss_final = train_loss
                test_loss_final = test_loss
                reached_threshold = True
                break

            train_loss_log.append(train_loss)
            test_loss_log.append(test_loss)
            if (epoch+1) % print_freq == 0:
                print(f' Epoch[{epoch + 1}] Training Loss: {train_loss:.8} Test Loss: {test_loss:.8}')

    # Without early stopping, report the best logged losses. This is computed
    # once here instead of re-scanning the whole log every epoch.
    if not reached_threshold and train_loss_log:
        train_loss_final = np.min(train_loss_log)
        test_loss_final = np.min(test_loss_log)
    return model, train_loss_log, test_loss_log, train_loss_final, test_loss_final

In [None]:
hidden_unit = 20000 #5000

# One shared initialization: every experiment below reloads this state dict.
net_init = precond_net(dim_in = data_dim, num_neurons = hidden_unit).state_dict()

# H on the auxiliary points Q at initialization, flattened to (hidden_unit*data_dim) x (number of Q points).
# reshape (not view): get_H returns a permuted tensor, and view raises when the
# strides do not allow merging the first two axes.
H_Q = get_H(Q, net_init['dense_w.weight'], net_init['weight_a'])
H_Q = H_Q.reshape(hidden_unit*data_dim, -1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.version.cuda)
print('GPU is available: ' + str(torch.cuda.is_available()))
num_of_gpus = torch.cuda.device_count()
print('number of GPUs is: '+ str(num_of_gpus))

# move the data and the preconditioner to the training device
X_all_train = X_all_train.to(device)
Y_all_train = Y_all_train.to(device)
X_all_test = X_all_test.to(device)
Y_all_test = Y_all_test.to(device)
H_Q = H_Q.to(device)

In [None]:
# Perform training with training data of size from `smallest_data_num` to `largest_data_num` with step `data_num_step`
# `largest_data_num` is included: the sweep loops iterate over
# range(smallest_data_num, largest_data_num + data_num_step, data_num_step).
smallest_data_num = 100
largest_data_num = 1000
data_num_step = 100

### Baseline Model

In [None]:
vanilla_train_loss_final_all = []
vanilla_test_loss_final_all = []

# np.save does not create missing directories, so make sure the output folder exists.
os.makedirs('data', exist_ok=True)

for data_num in range(smallest_data_num, largest_data_num+data_num_step, data_num_step):
    print(data_num)
    # fresh network, reset to the shared initialization
    net = precond_net(dim_in = data_dim, num_neurons = hidden_unit)
    net.load_state_dict(net_init)
    net = net.to(device)

    # NOTE(review): the test set is also truncated to its first `data_num` rows; confirm this is intended.
    trained_model, train_loss_log, test_loss_log, train_loss_final, test_loss_final = train_base(
    net, X_all_train[:data_num, :], Y_all_train[:data_num, :], X_all_test[:data_num, :], Y_all_test[:data_num, :], 
    hidden_unit = hidden_unit, learning_rate = 5, epochs =5000, print_freq=100)

    vanilla_train_loss_final_all.append(train_loss_final)
    vanilla_test_loss_final_all.append(test_loss_final)

    # per-size loss curves
    np.save('data/vanilla_train_'+str(data_num)+'.npy', np.array(train_loss_log))
    np.save('data/vanilla_test_'+str(data_num)+'.npy', np.array(test_loss_log))

    # summary so far, rewritten every iteration so a partial sweep is still usable
    np.save('data/vanilla_train_loss_final_all.npy', np.array(vanilla_train_loss_final_all))
    np.save('data/vanilla_test_loss_final_all.npy', np.array(vanilla_test_loss_final_all))


In [None]:
import gc

# Sanity check: shapes of H and of the first-layer weights for the smallest training set.
data_num = smallest_data_num
X_train = X_all_train[:data_num, :]
net = precond_net(dim_in = data_dim, num_neurons = hidden_unit)
net.load_state_dict(net_init)
net = net.to(device)
H = get_H(X_train, net.dense_w.weight.data, net.weight_a.data)
print(H.shape)
print(net.dense_w.weight.data.shape)

# Drop the references first: while `H` and `net` are still bound, neither the
# garbage collector nor the CUDA caching allocator can release their memory.
del H, net
gc.collect()
torch.cuda.empty_cache()


### Preconditioned Model

In [None]:
kernel_train_loss_final_all = []
kernel_test_loss_final_all = []

# np.save does not create missing directories, so make sure the output folder exists.
os.makedirs('data', exist_ok=True)

for data_num in range(smallest_data_num, largest_data_num+data_num_step, data_num_step):
    print(data_num)
    # fresh network, reset to the shared initialization
    net = precond_net(dim_in = data_dim, num_neurons = hidden_unit)
    net.load_state_dict(net_init)
    net = net.to(device)

    # NOTE(review): the test set is also truncated to its first `data_num` rows; confirm this is intended.
    trained_model, train_loss_log, test_loss_log, train_loss_final, test_loss_final = train_kernel_preconditioned(
    H_Q, net, X_all_train[:data_num, :], Y_all_train[:data_num, :], X_all_test[:data_num, :], Y_all_test[:data_num, :], 
    hidden_unit=hidden_unit, learning_rate=300, epochs =5000, print_freq=100)

    kernel_train_loss_final_all.append(train_loss_final)
    kernel_test_loss_final_all.append(test_loss_final)

    # per-size loss curves
    np.save('data/kernel_train_'+str(data_num)+'.npy', np.array(train_loss_log))
    np.save('data/kernel_test_'+str(data_num)+'.npy', np.array(test_loss_log))

    # summary so far, rewritten every iteration so a partial sweep is still usable
    np.save('data/kernel_train_loss_final_all.npy', np.array(kernel_train_loss_final_all))
    np.save('data/kernel_test_loss_final_all.npy', np.array(kernel_test_loss_final_all))
