In [157]:
using MLDatasets, Flux, LinearAlgebra, ProgressMeter, Zygote, ReverseDiff
include("DenseNTK.jl"); include("FastNTKMethods.jl")

"""
    load_MNIST()

Load the MNIST handwritten-digit dataset: 10 classes (digits 0 to 9), each image
28x28 grayscale pixels.

Both the `:train` and `:test` splits are read through `MNIST(split = ...)`. The image
arrays are passed through `Flux.flatten` (one column per image) and the labels are
one-hot encoded over the classes `0:9` with `Flux.onehotbatch`.

# Returns
A tuple `(X_training, Y_training, X_testing, Y_testing)`, where the `X_*` entries are
the flattened images and the `Y_*` entries are the one-hot encoded labels.
"""
function load_MNIST()
    # Indexing a split with `[:]` yields the full (features, targets) pair.
    X_training, Y_training = MNIST(split = :train)[:]
    X_testing, Y_testing = MNIST(split = :test)[:]

    # Collapse the pixel dimensions so every sample becomes a single column.
    X_training = Flux.flatten(X_training)
    X_testing = Flux.flatten(X_testing)

    # Integer labels -> one-hot columns over the ten digit classes.
    Y_training = Flux.onehotbatch(Y_training, 0:9)
    Y_testing = Flux.onehotbatch(Y_testing, 0:9)

    return X_training, Y_training, X_testing, Y_testing
end

### MODELS
# Four-layer sigmoid network built from the project's `DenseNTK` layers
# (loaded via the `include`d project files): 784 -> 16 -> 16 -> 16 -> 10.
# Parameter counts below assume a weight matrix plus a bias vector per layer.
DenseNTKmodel = let inputs = 28 * 28, hidden = 16, classes = 10
    Chain(
        DenseNTK(inputs, hidden, sigmoid),    # 784 x 16 + 16 = 12560 parameters
        DenseNTK(hidden, hidden, sigmoid),    #  16 x 16 + 16 =   272 parameters
        DenseNTK(hidden, hidden, sigmoid),    #  16 x 16 + 16 =   272 parameters
        DenseNTK(hidden, classes, sigmoid),   #  16 x 10 + 10 =   170 parameters
    )
end

# Reference network with the same 784 -> 16 -> 16 -> 16 -> 10 sigmoid layout,
# built from Flux's standard `Dense` layers.
model = let inputs = 28 * 28, hidden = 16, classes = 10
    Chain(
        Dense(inputs, hidden, sigmoid),    # 784 x 16 + 16 = 12560 parameters
        Dense(hidden, hidden, sigmoid),    #  16 x 16 + 16 =   272 parameters
        Dense(hidden, hidden, sigmoid),    #  16 x 16 + 16 =   272 parameters
        Dense(hidden, classes, sigmoid),   #  16 x 10 + 10 =   170 parameters
    )
end

### DATA 
# Number of training samples to keep for the kernel computations.
N = 1000
# First element of the tuple returned by `load_MNIST` is the flattened training
# images; keep only its first N columns.
x = first(load_MNIST())[:, 1:N];

In [152]:
# Build the kernel of `DenseNTKmodel` on `x` three times, varying only the last
# argument (1, 2, 3), evaluated in that order.
# NOTE(review): judging by the recorded output, the last argument appears to select
# the Jacobian backend (ReverseDiff / Zygote) and `true` to enable timing output --
# confirm against the definition of `kernel` in the included files.
K1, K2, K3 = ntuple(method -> kernel(DenseNTKmodel, x, true, method), 3);

Progress: 100%|█████████████████████████████████████████| Time: 0:00:12



ReverseDiff Jacobian: 7.679137 seconds (86.08 k allocations: 2.493 GiB, 1.40% gc time, 1.39% compilation time)
Kernel computation: 7.956876 seconds (3 allocations: 762.940 MiB, 16.17% gc time)


Progress: 100%|█████████████████████████████████████████| Time: 0:00:14



ReverseDiff Jacobian: 12.059156 seconds (87.48 M allocations: 6.502 GiB, 7.73% gc time, 0.34% compilation time)
Kernel computation: 6.405424 seconds (3 allocations: 762.940 MiB, 0.02% gc time)

Zygote Jacobian: 65.605930 seconds (4.01 M allocations: 160.542 GiB, 18.98% gc time, 1.37% compilation time: 2% of which was recompilation)
Kernel computation: 6.467038 seconds (3 allocations: 762.940 MiB, 1.08% gc time)


In [153]:
# Spectra of the three kernel matrices. Only the eigenvalues are used, so call
# `eigvals` directly: `eigen(K).values` would also compute (and then discard) the
# full set of eigenvectors for each matrix.
E1 = eigvals(K1)
E2 = eigvals(K2)
E3 = eigvals(K3)

println("Eigen values:")
# One column per kernel so the three spectra can be compared row by row.
hcat(E1, E2, E3)

Eigen values:


10000×3 Matrix{Float64}:
 -9.3875e-16   -1.21846e-15  -1.53385e-15
 -4.0555e-16   -5.95731e-16  -6.77098e-16
 -3.18102e-16  -5.29475e-16  -6.27329e-16
 -2.40281e-16  -3.56589e-16  -3.58033e-16
 -1.93665e-16  -2.69777e-16  -2.6207e-16
 -1.27721e-16  -1.27346e-16  -1.32238e-16
 -1.20999e-16  -1.16049e-16  -1.08502e-16
 -1.12735e-16  -9.64305e-17  -8.82988e-17
 -6.45881e-17  -7.38044e-17  -8.39919e-17
 -5.92228e-17  -5.85043e-17  -6.25166e-17
  ⋮                          
 11.6726       11.6654       11.6654
 12.4809       12.449        12.449
 14.5037       14.5058       14.5058
 15.9846       15.9985       15.9985
 17.3659       17.4003       17.4003
 19.8271       19.8216       19.8216
 20.953        20.9799       20.9799
 21.7508       21.7453       21.7453
 25.7531       25.7823       25.7823

In [154]:
#display(K1); display(K2); display(K3)

<h1>5% of MNIST</h1>

In [158]:
### DATA 
# NOTE(review): the heading of this section says "5% of MNIST", but N = 1 keeps a
# single training image -- confirm which sample count is intended here.
N = 1
# Flattened training images restricted to the first N columns (range indexing keeps
# the result two-dimensional).
x5 = first(load_MNIST())[:, 1:N];

In [159]:
# Recompute the kernel on the reduced input `x5`; this rebinds `K1`, replacing the
# value computed in the earlier cell.
# NOTE(review): the recorded run of this cell raised the ErrorException shown below
# -- check what input `kernel` expects before relying on this result.
K1 = kernel(DenseNTKmodel, x5, true, 1);

ErrorException: Input data type: Int64 is neither a matrix or column vector