### Comparing kernel matrices with different approaches

In [1]:
using Flux,LinearAlgebra,CairoMakie

In [16]:
include("DenseNTK.jl")
include("normNTK.jl")

kernel (generic function with 1 method)

We start by building a model from `DenseNTK` layers. `DenseNTK` works like Flux's `Dense`, but it is more general in that it also supports calculating the (neural tangent) kernel.
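For reference, the quantity compared below is presumably the empirical neural tangent kernel. For a scalar network output $f(x;\theta)$ with $P$ parameters ($P = 3N_h + 1$ for the model below), evaluated on inputs $x_1,\dots,x_{N_x}$, it is the Gram matrix of the parameter gradients. The symbols $\Theta$ and $J$ are introduced here only for illustration:

```latex
\Theta(x, x') = \nabla_\theta f(x;\theta)^{\top}\, \nabla_\theta f(x';\theta),
\qquad
J = \big[\,\nabla_\theta f(x_1;\theta)\ \cdots\ \nabla_\theta f(x_{N_x};\theta)\,\big] \in \mathbb{R}^{P \times N_x},
\qquad
K_{ij} = \Theta(x_i, x_j) = \big(J^{\top} J\big)_{ij}.
```

In the notebook, `K_i` plays the role of $J$ (one gradient column per input point), so $K = J^{\top} J$ is symmetric and positive semidefinite.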


In [17]:
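Nh = 100_000                                                 # hidden-layer width
model = Chain(DenseNTK(1=>Nh,tanh),DenseNTK(Nh=>1))|>f64    # 1 → Nh (tanh) → 1, converted to Float64
θ = Flux.params(model)                                       # implicit parameters: weight and bias of each layer
Nx = 21                                                      # number of input points
xa = -1.0
xb = 1.0
xR = range(xa,stop=xb,length=Nx)
x = hcat(xR...)                                              # 1×Nx input matrix (one column per sample)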


1×21 Matrix{Float64}:
 -1.0  -0.9  -0.8  -0.7  -0.6  -0.5  …  0.4  0.5  0.6  0.7  0.8  0.9  1.0
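To make the shapes in this cell concrete, here is a hypothetical plain-Julia stand-in for `Chain(DenseNTK(1=>Nh,tanh), DenseNTK(Nh=>1))`. It assumes a standard `Dense`-style parameterization (the real `DenseNTK` may scale its weights differently), and the initialization, including the `1/sqrt(Nh)` output scaling, is an illustrative choice. It shows why the gradient matrix below has `3*Nh+1` rows.

```julia
using Random

# ASSUMPTION: standard Dense-style layers; DenseNTK's parameterization may differ.
# Hypothetical helper returning the four parameter arrays of a 1 → Nh (tanh) → 1 network.
function init_two_layer(Nh::Integer; rng = Random.default_rng())
    return (W1 = randn(rng, Nh, 1),              # Nh×1 hidden weights
            b1 = randn(rng, Nh),                 # Nh hidden biases
            W2 = randn(rng, 1, Nh) ./ sqrt(Nh),  # 1×Nh output weights (illustrative scaling)
            b2 = zeros(1))                       # single output bias
end

# Forward pass on a 1×N input matrix, returning a 1×N output matrix
two_layer_forward(p, x::AbstractMatrix) = p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2

# Total parameter count: Nh + Nh + Nh + 1 = 3Nh + 1
count_params(p) = sum(length, values(p))

p_demo = init_two_layer(100; rng = MersenneTwister(0))
x_demo = reshape(collect(range(-1.0, 1.0, length = 21)), 1, :)   # same grid as above, 1×21
y_demo = two_layer_forward(p_demo, x_demo)
```

Reshaping the range gives the same 1×21 matrix as `hcat(xR...)` without splatting 21 arguments.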

In [18]:

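# Column i of K_i holds the gradient of the scalar output f(x_i) w.r.t. all 3Nh+1 parameters
K_i = zeros(Float64,3*Nh+1,Nx)
for i = 1:Nx
    ∇_SIE = Flux.gradient(()-> model([x[i]])[1],θ)   # implicit-parameter gradient of f(x_i)
    K_i[1:Nh,i] = ∇_SIE[θ[1]][:]                     # first-layer weights
    K_i[Nh+1:2*Nh,i] = ∇_SIE[θ[2]][:]                # first-layer biases
    K_i[2*Nh+1:3*Nh,i] = ∇_SIE[θ[3]][:]              # second-layer weights
    K_i[3*Nh+1,i] = ∇_SIE[θ[4]][1]                   # second-layer (output) bias
end

# Gram matrix of the gradients; the output-bias row (3Nh+1) is left out of this product
Kernel_SIE = K_i[1:3*Nh,:]'*K_i[1:3*Nh,:]
# eigen(Kernel_SIE).values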

21×21 Matrix{Float64}:
  1.31778     1.28648     1.24904    …  -0.254241   -0.328738   -0.396332
  1.28648     1.26136     1.23053       -0.191065   -0.263319   -0.328738
  1.24904     1.23053     1.20689       -0.121153   -0.191065   -0.254241
  1.20434     1.19286     1.17695       -0.0439198  -0.111384   -0.172249
  1.15113     1.14703     1.13935        0.0411034  -0.0237888  -0.0822643
  1.08812     1.09164     1.09258    …   0.134139    0.0719842   0.0160084
  1.01407     1.02529     1.03506        0.234975    0.1758      0.122495
  0.928078    0.946858    0.965415       0.342699    0.286877    0.236522
  0.829931    0.85588     0.882872       0.455407    0.403501    0.356529
  0.720614    0.753085    0.787816       0.569996    0.522795    0.479828
  ⋮                                  ⋱                           ⋮
  0.356529    0.403501    0.455407       0.882872    0.85588     0.829931
  0.236522    0.286877    0.342699       0.965415    0.946858    0.928078
  0.122495    0.1758
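The loop above builds the gradient matrix numerically with `Flux.gradient`. As a cross-check of the idea, here is a hypothetical Flux-free sketch that writes the gradients of a one-hidden-layer tanh network out analytically, assuming the plain form $f(x) = \sum_j v_j \tanh(w_j x + b_j) + c$ (not necessarily what `DenseNTK` computes). Rows are stacked in the same order as `K_i`. Because $\partial f/\partial c = 1$ in this parameterization, keeping the output-bias row simply adds 1 to every kernel entry. A central finite difference checks one gradient entry.

```julia
using LinearAlgebra, Random

# ASSUMPTION: plain network f(x) = Σ_j v_j tanh(w_j x + b_j) + c; all names are hypothetical.
f_scalar(w, b, v, c, x) = sum(v .* tanh.(w .* x .+ b)) + c

# Analytic gradient at scalar input x, stacked like the rows of K_i: [∂w; ∂b; ∂v; ∂c]
function param_gradient(w, b, v, c, x)
    t = tanh.(w .* x .+ b)            # hidden activations
    s = v .* (1 .- t .^ 2)            # v_j * tanh'(w_j x + b_j)
    return vcat(s .* x, s, t, 1.0)    # ∂f/∂w, ∂f/∂b, ∂f/∂v, ∂f/∂c (= 1)
end

rng_c = MersenneTwister(1)
Nh_c = 50
w_c = randn(rng_c, Nh_c)
b_c = randn(rng_c, Nh_c)
v_c = randn(rng_c, Nh_c) ./ sqrt(Nh_c)
c_c = 0.0
xs_c = range(-1.0, 1.0, length = 21)

# Gradient matrix: one column per input point, 3Nh+1 rows
J_grad = reduce(hcat, [param_gradient(w_c, b_c, v_c, c_c, xi) for xi in xs_c])

K_full   = J_grad' * J_grad                          # all parameters, incl. output bias
K_nobias = J_grad[1:end-1, :]' * J_grad[1:end-1, :]  # output bias excluded, as in Kernel_SIE

# Central finite-difference check of ∂f/∂w_1 at x = 0.3
hstep = 1e-6
w_plus, w_minus = copy(w_c), copy(w_c)
w_plus[1] += hstep
w_minus[1] -= hstep
fd_dw1 = (f_scalar(w_plus, b_c, v_c, c_c, 0.3) - f_scalar(w_minus, b_c, v_c, c_c, 0.3)) / (2hstep)
```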

In [19]:
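K = kernel(model,x)           # the same kernel, computed by `kernel` from normNTK.jl
# round.(K[:,:],digits=13)==round.(Kernel_SIE[:,:],digits=13)
# isapprox(K, Kernel_SIE)     # tolerance-based alternative to comparing rounded entries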



21×21 Matrix{Float64}:
  1.31778     1.28648     1.24904    …  -0.254241   -0.328738   -0.396332
  1.28648     1.26136     1.23053       -0.191065   -0.263319   -0.328738
  1.24904     1.23053     1.20689       -0.121153   -0.191065   -0.254241
  1.20434     1.19286     1.17695       -0.0439198  -0.111384   -0.172249
  1.15113     1.14703     1.13935        0.0411034  -0.0237888  -0.0822643
  1.08812     1.09164     1.09258    …   0.134139    0.0719842   0.0160084
  1.01407     1.02529     1.03506        0.234975    0.1758      0.122495
  0.928078    0.946858    0.965415       0.342699    0.286877    0.236522
  0.829931    0.85588     0.882872       0.455407    0.403501    0.356529
  0.720614    0.753085    0.787816       0.569996    0.522795    0.479828
  ⋮                                  ⋱                           ⋮
  0.356529    0.403501    0.455407       0.882872    0.85588     0.829931
  0.236522    0.286877    0.342699       0.965415    0.946858    0.928078
  0.122495    0.1758
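The two printed matrices agree. When checking this programmatically, comparing rounded entries can report a spurious mismatch if two nearly equal values fall on opposite sides of a rounding boundary, while `isapprox` compares the matrices up to a relative tolerance (by default in the Frobenius norm). The sketch below exaggerates this with `digits = 2`; the matrices and values are made up for illustration.

```julia
using LinearAlgebra   # array method of isapprox (norm-based comparison)

# Two matrices that differ by ~2e-12 but straddle a 2-digit rounding boundary
s1, s2 = 0.004999999999, 0.005000000001
A_demo = [1.0 s1; s1 1.0]
B_demo = [1.0 s2; s2 1.0]

same_rounded = round.(A_demo, digits = 2) == round.(B_demo, digits = 2)  # false: 0.0 vs 0.01
same_approx  = isapprox(A_demo, B_demo; rtol = 1e-8)                     # true: ‖A - B‖ ≤ rtol·max(‖A‖, ‖B‖)
```

With `digits = 13` the same failure is rarer but still possible, so `isapprox` is the more robust check.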

In [24]:
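# eigen(Kernel_SIE).values   # the manual kernel should give the same spectrum
λ = eigen(K).values          # K is symmetric, so its eigenvalues are real
λ = abs.(λ)                  # K is PSD in exact arithmetic; tiny negative eigenvalues are round-off
λ = sort!(λ)                 # ascending order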

21-element Vector{Float64}:
  5.175604272701515e-16
  6.484742651941025e-16
  7.36829962326064e-14
  2.2306540355486444e-13
  3.296207572712412e-12
  1.2604049983123277e-11
  1.2399689247478252e-10
  4.970631423631829e-10
  4.411356107667226e-9
  1.8167305863891508e-8
  ⋮
  5.200757370384124e-6
  3.592398810475665e-5
  0.0002200356002425372
  0.0016066144796224458
  0.00882438435769005
  0.08066577108622834
  0.3592886645227367
  8.229030212011496
 14.655461425843749
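The spectrum spans about 16 orders of magnitude. With `eps(Float64) ≈ 2.2e-16` and a largest eigenvalue of about 14.66, anything below roughly `3.3e-15` (here the two smallest values) is at round-off level. Because the eigenvalues of $J^{\top}J$ are the squared singular values of $J$, one option, sketched below on a random stand-in matrix, is to call `svdvals` on the gradient matrix itself (in the notebook, presumably `K_i[1:3*Nh,:]`). This avoids forming $J^{\top}J$, which squares the condition number, and typically resolves the smallest eigenvalues more accurately. The matrix here is hypothetical and well conditioned, so both routes agree.

```julia
using LinearAlgebra, Random

# Hypothetical gradient matrix with the same layout as K_i (rows = parameters, columns = inputs)
J_demo = randn(MersenneTwister(2), 200, 8)
K_demo = J_demo' * J_demo                        # Gram matrix, PSD in exact arithmetic

# Route used above: eigenvalues of the kernel, magnitudes, sorted ascending
λ_eig = sort(abs.(eigvals(Symmetric(K_demo))))

# Alternative: squared singular values of J, without forming J'J
λ_svd = sort(svdvals(J_demo) .^ 2)

# Flag eigenvalues that are indistinguishable from round-off in the largest one
at_roundoff = λ_eig .< eps(Float64) * maximum(λ_eig)
```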

In [36]:
fig = Figure()
ax1 = Axis(fig[1,1],yscale=log10)   # pass the function log10, not the Symbol :log10
CairoMakie.scatter!(ax1,λ)          # sorted |eigenvalues| of K on a log scale
fig

Passing `yscale=:log10` (a `Symbol`) fails with `MethodError: no method matching defaultlimits(::Symbol)`, because Makie's `Axis` expects a scale function such as `log10`.
