# Analysis of neural tangent kernel performance

Given the pre-generated neural tangent kernel (NTK) output from the main code (by default in the directory `'./kernel_output'`), we examine the classification performance on the MNIST dataset of the exact, sparsified, and diagonal NTKs. Additionally, for the quantum algorithms of sparsified and diagonal NTKs, the condition number and the number of measurements required for post-selection/readout are verified to be bounded by $O(\log n)$.

In [None]:
# Array maths, plus glob for locating the saved kernel files by filename pattern.
import numpy as np
import glob

# Request the 'pdf' and 'svg' formats for figures rendered inline by the notebook.
# NOTE(review): newer IPython releases deprecate this helper in favour of
# matplotlib_inline.backend_inline.set_matplotlib_formats -- confirm against the
# installed version.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib
import seaborn as sns

# Global seaborn theme: font scale 1.3, "whitegrid" style, axes face colour ".97".
sns.set(font_scale=1.3)
sns.set_style("whitegrid", {"axes.facecolor": ".97"})

import matplotlib.pyplot as plt

## Sparsity pattern

First, a sparsity pattern is constructed in $\tilde O(n)$ time. In the proposed quantum algorithm, this is performed once when the data is stored in a binary QRAM data structure (also in $\tilde O(n)$ time). Given a sparsity pattern with at most $s = O(\log n)$ nonzero elements in any row or column, multiple neural networks (of different architectures) can be efficiently trained in logarithmic time using the same sparsity pattern.

In [None]:
def get_target_sparsity(m):
    """
    Return the target sparsity for the matrix `m`: the natural logarithm of its
    number of rows, i.e. O(log n) for an n x n matrix.
    """
    n_rows = m.shape[0]
    return np.log(n_rows)

def block_diagonal(m):
    """
    Build the two class masks for the NTK matrix `m`, whose first half of the rows
    and columns is treated as class 0 and second half as class 1.

    Returns `(class_0, class_1)`: float arrays that are 1 on the top-left
    (class 0) or bottom-right (class 1) block and 0 elsewhere. Their sum is the
    block diagonal matrix [[1, 0], [0, 1]].

    Each mask has side `2 * (m.shape[0] // 2)`, so it matches `m` only when the
    number of rows is even.
    """
    half = m.shape[0]//2
    side = 2*half

    class_0 = np.zeros((side, side))
    class_0[:half, :half] = 1

    class_1 = np.zeros((side, side))
    class_1[half:, half:] = 1

    return class_0, class_1

def get_sparsity_pattern(m):
    """
    Draw a random symmetric 0/1 sparsity pattern with the shape of the n x n
    matrix `m`, using NumPy's global pseudorandom generator.

    The pattern has ones on the diagonal, zeros between the two class blocks (see
    `block_diagonal`), and within each class block a random selection of entries
    whose expected count per row is set by `get_target_sparsity(m)`.

    This stands in for the O(n log n) index-based construction described in the
    text; the reference implementation here materialises a dense n x n array.
    """
    n = m.shape[0]

    # Only the lower triangle of the sampled matrix is kept (and mirrored) below,
    # so roughly half of the sampled indices survive; the per-element probability
    # is doubled to compensate, capped at 1.
    p_one = min(2*get_target_sparsity(m)/n, 1.0)

    pattern = np.zeros(m.shape)
    for row in range(n):
        # draw how many indices this row gets, then the column indices themselves
        # (sampled with replacement, so repeated indices collapse onto one entry)
        count = np.random.binomial(n, p_one)
        columns = np.random.randint(n, size=count)
        pattern[row, columns] = 1

    # symmetrise: lower triangle (with diagonal) plus the mirror image of the
    # strictly-lower triangle
    lower_with_diagonal = np.tril(pattern)
    mirrored_upper = np.tril(pattern, -1).T
    pattern = lower_with_diagonal + mirrored_upper

    # zero all elements between opposite classes; since the NTK is larger for more
    # similar data examples, this biases the sparse matrix towards selecting more
    # important examples
    class_0, class_1 = block_diagonal(m)
    pattern = pattern * (class_0 + class_1)

    # the diagonal is always kept
    np.fill_diagonal(pattern, 1)

    return pattern

def sparsify_unbiased(m, sparsity_pattern):
    """
    Return the NTK matrix `m` masked elementwise by `sparsity_pattern`, with no
    further thresholding. Used for the fully-connected network.
    """
    masked = m * sparsity_pattern
    return masked

def sparsify_biased(m, sparsity_pattern, t0, t1):
    """
    Sparsify the NTK matrix `m` with `sparsity_pattern`, additionally dropping
    every element that is not strictly greater than `t0` (class 0 block) or `t1`
    (class 1 block), then add a diagonal conditioning term.
    Used for the convolutional network.

    The diagonal of `m` is always kept. The returned matrix is the sparsified
    kernel plus `f` times its own diagonal, with
    `f = 0.76 * get_target_sparsity(m)**0.9`.
    """
    class_0, class_1 = block_diagonal(m)

    # an element survives only if the pattern selects it and it exceeds the
    # threshold of the class block it belongs to
    above_threshold = (m > t0) * class_0 + (m > t1) * class_1
    keep = sparsity_pattern * above_threshold
    np.fill_diagonal(keep, 1)

    sparse_kernel = m * keep

    # we expect a factor of ~target_sparsity by Gershgorin's theorem
    # empirically, the well-conditioning of the kernel makes it scale better than this
    scale = 0.76 * get_target_sparsity(m)**0.9
    size = sparse_kernel.shape[0]
    diagonal_boost = scale * np.diag(sparse_kernel) * np.eye(size)
    return sparse_kernel + diagonal_boost

def compute_class_percentiles(m, percentile):
    """
    Compute the truncation thresholds `(t0, t1)` for `sparsify_biased` from the
    kernel `m`. This is evaluated over a small subset (n = 16) of the training set
    to efficiently bias the sparsification towards large off-diagonal elements.

    For each class, `m` is masked to the off-diagonal entries of that class block
    and the requested percentile of the absolute values is returned.

    NOTE(review): the percentile is taken over the whole masked matrix, so the
    zeroed entries (other class, cross-class blocks, diagonal) take part in the
    ranking -- confirm this is the intended definition.
    """
    half = m.shape[0]//2
    side = 2*half

    # ones everywhere inside a class block except on its diagonal
    off_diagonal = np.ones((half, half)) - np.eye(half)

    thresholds = []
    for start in (0, half):
        stop = start + half
        mask = np.zeros((side, side))
        mask[start:stop, start:stop] = off_diagonal
        thresholds.append(np.percentile(np.abs(m * mask), percentile))

    t0, t1 = thresholds
    return t0, t1

def get_sparsity(m):
    """
    Get maximum number of nonzero elements in any row or column of `m`.

    Both axes of a 2-D input are inspected, so the result is also correct for
    non-symmetric matrices. A 1-D input yields its total number of nonzero elements.
    """
    nonzero = np.asarray(m) != 0
    per_column = np.sum(nonzero, axis=0)
    if nonzero.ndim < 2:
        # no second axis to inspect
        return np.amax(per_column)
    per_row = np.sum(nonzero, axis=1)
    return max(np.amax(per_column), np.amax(per_row))

We verify that the sparsity of the pattern (the maximum number of nonzero elements in any row or column) does indeed scale like $O(\log n)$.

In [None]:
# Training set sizes to probe, and how many random patterns to draw for each size.
Ns = [16, 32, 64, 128, 256, 512]
sparsity_trials = 100

sparsities = np.zeros(len(Ns))
sparsities_std = np.zeros(len(Ns))
for i, N in enumerate(Ns):
    # sparsity (as reported by get_sparsity) of each independently drawn pattern;
    # only the shape of the all-zero N x N matrix matters here
    sparsities_N = [get_sparsity(get_sparsity_pattern(np.zeros((N, N))))
                    for _ in range(sparsity_trials)]
    sparsities[i] = np.mean(sparsities_N)
    # standard error of the mean over the trials
    sparsities_std[i] = np.std(sparsities_N)/np.sqrt(len(sparsities_N))

# Mean sparsity against training set size on a log-scaled x axis; the error bars
# are two standard errors.
plt.figure(figsize=(5, 4))
plt.errorbar(Ns, sparsities, yerr=2*sparsities_std, fmt='o', c='C1')
plt.xlabel('Training set size')
plt.ylabel('Sparsity')
plt.xscale('log')
plt.xticks(Ns)
plt.minorticks_off()
x_axis = plt.gca().get_xaxis()
x_axis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
plt.tight_layout()
plt.show()

## Neural network performance

Four quantities characterize the infinite-width neural network and its sparsified and diagonal approximations:
* Binary classification accuracy: all three networks are evaluated on a balanced sample of the MNIST test set (separate from the training set).
* Condition number: to invert the sparsified NTK $\tilde K$ efficiently with a quantum linear systems algorithm, the condition number $\kappa(\tilde K)$ (defined to be the ratio of the largest to smallest singular values) must be bounded by $O(\log n)$.
* Post-selection measurements: to prepare the quantum state $|k_*\rangle = \frac{1}{\sqrt{P}} \sum_{i=0}^{n-1} k_i |i\rangle$ of the NTK evaluated between test data $\mathbf x_*$ and the training data $\{\mathbf x_i\}$, we require $O(1/P)$ measurements for $P = \sum_i k_i^2$. Here, $k_i$ corresponds to the kernel $k(\mathbf x_*, \mathbf x_i)$ normalized and clipped to lie within $-1 \leq k_i \leq 1$. To efficiently prepare the state, the number of measurements must be bounded by $O(\log n)$.
* Readout measurements: to perform the final readout, we estimate the sign of the state overlap $o = \langle k_* | y \rangle$ (for the diagonal approximation) or $o = \langle k_* | \tilde K^{-1} | y\rangle$ (for the sparsified approximation). This requires $O(1/|o|^2)$ measurements, which must be bounded by $O(\log n)$ for efficient readout.

In [None]:
def classify(ntk_mean):
    """
    Turn the raw NTK outputs on the test dataset into class labels by taking the
    sign of each output relative to the median output. This assumes the test data
    is sampled i.i.d. from the underlying data distribution (i.e. balanced).

    Outputs exactly equal to the median are mapped to 0.
    """
    centred = ntk_mean - np.median(ntk_mean)
    return np.sign(centred)

def get_file_prefix(fp, seed, N, trial):
    """
    Build the common filename prefix of the NTK output files:
    `<fp>_seed<seed>_data<N>_trial<trial>_`.
    """
    pieces = (fp, '_seed', str(seed), '_data', str(N), '_trial', str(trial), '_')
    return ''.join(pieces)

def _unit_solve(matrix, rhs):
    """Solve `matrix @ x = rhs` and return `x` rescaled to unit Euclidean norm."""
    x = np.linalg.solve(matrix, rhs)
    return x / np.sqrt(np.sum(x**2))


def _poisson_median_stderr(values, axis=None, bootstraps=5):
    """Spread of the median of `values` under random Poisson reweighting."""
    # NOTE(review): every value is multiplied by its Poisson draw instead of being
    # repeated that many times; this estimator is kept as found -- confirm that it
    # is the intended form of Poisson bootstrapping.
    medians = np.array([
        np.median(np.random.poisson(size=values.shape) * values, axis=axis)
        for _ in range(bootstraps)
    ])
    return np.std(medians, axis=0)/np.sqrt(bootstraps)


def _label_log_axis(Ns, ylabel):
    """Shared axis styling: log-scaled x axis with one plain-number tick per size."""
    plt.xlabel('Training set size')
    plt.ylabel(ylabel)
    plt.xscale('log')
    plt.xticks(Ns)
    plt.minorticks_off()
    plt.gca().get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())


def analyze(file_prefix, Ns, sparsify_fnc, sparsify_args=(), sparsity_bootstraps=3,
            plot_margin=0):
    """
    Plot the accuracy, condition number, number of measurements for post-selection, and
    number of measurements for readout.

    Parameters:
        file_prefix: path prefix of the saved NTK output; for every size N all
            files matching `get_file_prefix(file_prefix, '*', N, '*')` followed
            by `*<suffix>` are loaded.
        Ns: sequence of training set sizes to analyze.
        sparsify_fnc: callable invoked as
            `sparsify_fnc(kernel_train, sparsity_pattern, *sparsify_args)` that
            returns the sparsified training kernel.
        sparsify_args: extra positional arguments for `sparsify_fnc`.
        sparsity_bootstraps: number of random sparsity patterns drawn per saved
            training kernel.
        plot_margin: horizontal offset, as a fraction of N, used to separate
            markers that would otherwise overlap.

    Returns nothing; four figures are shown.

    Raises:
        ValueError: if no files are found for some N, or if the five kinds of
            output files are not present in equal numbers.
    """
    Ns = np.array(Ns)
    suffixes = ['kernel_train.npy', 'kernel_test.npy', 'kernel_test_normalized.npy',
                'train_label.npy', 'test_label.npy']

    accs_mean = []
    accs_std = []
    post_selections = []
    post_selections_std = []
    measurements = []
    measurements_std = []
    kappa = []
    kappa_std = []

    for N in Ns:
        # locate the saved output of every seed and trial for this size
        prefix = get_file_prefix(file_prefix, '*', N, '*')
        files = [sorted(glob.glob(prefix + '*' + suffix)) for suffix in suffixes]
        counts = [len(group) for group in files]
        if counts[0] == 0:
            raise ValueError('no NTK output files match ' + prefix + '*' + suffixes[0])
        if len(set(counts)) != 1:
            # the sorted lists are paired by position, so unequal counts would
            # combine kernels and labels from different runs
            raise ValueError('mismatched NTK output files for ' + prefix + '*: '
                             + str(dict(zip(suffixes, counts))))

        all_dense = []
        all_sparse = []
        all_identity = []
        trial_p = []
        trial_overlaps_diag = []
        trial_overlaps_sparse = []
        kappas = []

        for paths in zip(*files):
            (kernel_train, kernel_test, kernel_test_normalized,
             train_label, test_label) = [np.load(path) for path in paths]

            # ---- quantities that do not depend on the sparsity pattern ----

            # solve A^{-1}y (normalized) for A being the exact and diagonal NTK
            kernel_train_identity = np.diag(kernel_train)*np.eye(kernel_train.shape[0])
            inv_y_dense = _unit_solve(kernel_train, train_label)
            inv_y_diag = _unit_solve(kernel_train_identity, train_label)

            # prepare |k_*> state
            ki = kernel_test_normalized / np.amax(np.abs(kernel_test_normalized))
            p = np.sum(ki**2, axis=1)
            ki = ki / np.sqrt(p[:, np.newaxis])

            # prepare |y> state
            y = train_label / np.sqrt(len(train_label))
            overlap_diag = ki @ y  # <k_*|y>

            # classify with the exact and diagonal NTKs
            correct_dense = classify(kernel_test @ inv_y_dense) == test_label
            correct_identity = (classify(kernel_test_normalized @ inv_y_diag)
                                == test_label)

            # ---- bootstrap over different sparsity patterns ----
            for _ in range(sparsity_bootstraps):
                sparsity_pattern = get_sparsity_pattern(kernel_train)
                kernel_train_sparse = sparsify_fnc(kernel_train, sparsity_pattern,
                                                   *sparsify_args)

                # condition number from the extreme eigenvalue magnitudes
                eigs = np.abs(np.linalg.eigvals(kernel_train_sparse))
                kappas.append(np.amax(eigs)/np.amin(eigs))

                inv_y_sparse = _unit_solve(kernel_train_sparse, train_label)

                # every draw contributes exactly one entry to each list
                trial_p.append(p)  # for post-selection measurements
                trial_overlaps_diag.append(overlap_diag)
                trial_overlaps_sparse.append(ki @ inv_y_sparse)  # <k_*|\tilde K^{-1}|y>

                correct_sparse = (classify(kernel_test_normalized @ inv_y_sparse)
                                  == test_label)

                # the exact and diagonal results are recorded once per draw as
                # well, so all three accuracy samples have the same length
                all_dense.append(correct_dense)
                all_sparse.append(correct_sparse)
                all_identity.append(correct_identity)

        # compute the mean and standard error of all quantities

        accs_mean_s = []
        accs_std_s = []
        for hits in (all_dense, all_sparse, all_identity):
            correct = np.concatenate(hits).astype(float)
            accs_mean_s.append(np.mean(correct))
            accs_std_s.append(np.std(correct)/np.sqrt(len(correct)))
        accs_mean.append(accs_mean_s)
        accs_std.append(accs_std_s)

        post_measurements = N/np.asarray(trial_p).ravel()
        post_selections.append(np.median(post_measurements))
        post_selections_std.append(_poisson_median_stderr(post_measurements))

        # row 0: diagonal NTK, row 1: sparsified NTK
        overlaps = np.array([trial_overlaps_diag, trial_overlaps_sparse]).reshape(2, -1)
        # enough measurements for stdev to be O(overlap)
        these_measurements = 1/overlaps**2 - 1
        measurements.append(np.median(these_measurements, axis=1))
        measurements_std.append(_poisson_median_stderr(these_measurements, axis=1))

        kappa.append(np.mean(kappas))
        kappa_std.append(np.std(kappas)/np.sqrt(len(kappas)))

    accs_mean = np.array(accs_mean)
    accs_std = np.array(accs_std)
    post_selections = (np.array(post_selections), np.array(post_selections_std))
    measurements = (np.array(measurements), np.array(measurements_std))
    kappa = np.array(kappa)
    kappa_std = np.array(kappa_std)

    # plot everything (all error bars are two standard errors)

    plt.figure(figsize=(5, 4))
    plt.errorbar(Ns - Ns*plot_margin, accs_mean[:, 0], yerr=2*accs_std[:, 0],
                 label='Exact NTK', fmt='o')
    plt.errorbar(Ns, accs_mean[:, 1], yerr=2*accs_std[:, 1], label='Sparse NTK', fmt='o')
    plt.errorbar(Ns + Ns*plot_margin, accs_mean[:, 2], yerr=2*accs_std[:, 2],
                 label='Diagonal NTK', fmt='o')
    _label_log_axis(Ns, 'Accuracy')
    plt.legend(loc='lower right')
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(5, 4))
    plt.errorbar(Ns, kappa, yerr=2*kappa_std, fmt='o', c='C1')
    _label_log_axis(Ns, 'Condition number')
    plt.gca().get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter(
                                                                         useOffset=False))
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(5, 4))
    plt.errorbar(Ns, post_selections[0], yerr=2*post_selections[1], fmt='o')
    _label_log_axis(Ns, 'Measurements (post-selection)')
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(5, 4))
    plt.errorbar(Ns - Ns*plot_margin/2, measurements[0][:, 1],
                 yerr=2*measurements[1][:, 1], label='Sparse NTK', c='C1', fmt='o')
    plt.errorbar(Ns + Ns*plot_margin/2, measurements[0][:, 0],
                 yerr=2*measurements[1][:, 0], label='Diagonal NTK', c='C2', fmt='o')
    _label_log_axis(Ns, 'Measurements (readout)')
    plt.legend()
    plt.tight_layout()
    plt.show()

Plot the results for the fully-connected neural network.

In [None]:
# Fully-connected network: the sparsity pattern is applied directly by
# sparsify_unbiased; markers are offset horizontally by N/8 to avoid overlap.
analyze(
    'kernel_output/fully-connected',
    Ns,
    sparsify_unbiased,
    plot_margin=1/8,
)

Estimate the appropriate normalization threshold for preparing $|k_*\rangle$ based on a small subset ($n=16$) of the training set, and then plot the results for the convolutional neural network.

In [None]:
fp = 'kernel_output/convolutional'
base_n = 16

# first (in sorted filename order) saved training kernel of size base_n
base_pattern = get_file_prefix(fp, '*', base_n, '*') + 'kernel_train.npy'
base_ntk = np.load(sorted(glob.glob(base_pattern))[0])

# per-class thresholds at the 90th percentile, passed on to sparsify_biased
sparsify_args = compute_class_percentiles(base_ntk, 90)
analyze(fp, Ns, sparsify_biased, sparsify_args=sparsify_args, plot_margin=1/8)