In [None]:
using Flux, LinearAlgebra, IterTools
using ProgressMeter

include("DenseNTK.jl")

In [None]:
function check_dim(x)
    """This function checks the appropriate  dimensions of input data"""
    if isa(x, Matrix)
        return size(x, 2)  # Returns the number of columns (width) of the matrix
    elseif isa(x, Vector)
        return 1  # Return 1 if it's a column vector
    else
        type = typeof(x)
        error("Input data type: $type is neither a matrix or column vector")
    end
end

In [None]:
function jac(model, x, f,param)
    """Gets the jacobian of a specific parameter"""
    jaco(f) = Flux.jacobian(() -> model(x)[f],Flux.params(model))
    return jaco(f)[Flux.params(model)[param]]
end


In [None]:

function Df(model, x)
    # x: single datapoint
    m = length(model(x))

    # Total amount of θ exluding final bias
    total_amount_of_θ = sum(length, Flux.params(model))  - length(Flux.params(model)[length(Flux.params(model))])

    # Skilgreini empty jacobian matrix
    Jacob = zeros(total_amount_of_θ,m)

    for func_i = 1:m
        current_col = Vector{Float64}(undef, 0) # Preallocate memory
        for param_i = 1:length(Flux.params(model)) - 1 # -1 because we don't want to include the final bias
            jac_vec = jac(model, x, func_i, param_i)[:]
            current_col = vcat(current_col, jac_vec) # Concatenate vectors
        end
        
        for k = 1:total_amount_of_θ 
            Jacob[k, func_i] = current_col[k]
        end
    end
    
    return Jacob # Þetta er Df fylkið í bilblíunni
end


In [None]:


function kernel(model, x, show_progress=false)
    N = check_dim(x)
    m = length(model(x[:,1]))  # Number of functions in the model output
    K = zeros(N*m, N*m)

    if show_progress
        p = Progress(N, 1, "Computing kernel:", 50)
    end

    for i = 1:N
        for j = 1:N
            block = Df(model, x[:,i])' * Df(model, x[:,j])
            K[(i-1)*m+1:i*m, (j-1)*m+1:j*m] .= block
        end
        if show_progress
            next!(p)  # Increment progress meter
        end
    end

    if show_progress
        finish!(p)  # Finish progress meter
    end

    return K
end

In [143]:
Nx=20
a=-1.0; b=1.0

xVec=collect(range(a,stop=b,length=Nx));
yVec=sin.(2*pi*xVec) .+ 0.1*randn(size(xVec)); 

Nh2 = 100 # found via quadratic equation
model = Chain(DenseNTK(1,Nh2,relu), DenseNTK(Nh2,Nh2,relu), DenseNTK(Nh2,20))

Chain(
  DenseNTK(Float32[-0.36745083; -0.28118017; … ; 0.38877422; -0.70553064;;], Float32[0.9655598, -0.6751468, -0.2864891, 0.22647955, 1.9972303, -0.65081954, 0.13882352, -0.6892691, -0.3263925, 0.68589026  …  1.0695404, -0.8641656, -0.9559, -0.12768815, -0.5637661, -1.5699892, -2.1146722, -0.17123432, -0.75673157, -0.9442188], NNlib.relu),  [90m# 200 parameters[39m
  DenseNTK(Float32[0.8233794 0.11223748 … 0.05494005 -1.123476; 0.79262894 -0.2469375 … -0.660826 -0.51515085; … ; -0.85384595 -0.16199361 … -1.0499369 -0.72358584; -0.09184618 1.3262553 … -0.13930495 -0.78035176], Float32[1.6822311, -0.98510396, 1.7433788, 0.22409013, 0.1910342, 1.4367994, -0.49651182, -0.2842329, -0.70850766, 0.4881591  …  -0.7513562, -0.55978525, 0.89840716, -1.908574, 0.76917416, -0.17845866, 0.2186963, -1.9196416, 1.8684497, 0.7008971], NNlib.relu),  [90m# 10_100 parameters[39m
  DenseNTK(Float32[-0.48429808 0.37855625 … -0.46772817 -1.8110397; 0.32301813 -0.6404664 … 2.3067555 0.7864435; … ; -

In [None]:
org_K = kernel(model, hcat(xVec...),true);

In [None]:
D = Flux.jacobian(() -> model(hcat(xVec...)),Flux.params(model))
D = hcat([grad for grad in D]...)
D = D[:,1:end-1]

j = jac(model, hcat(xVec[1]), 2, 4)

In [None]:
x = hcat(xVec...)
model(x)

D = Flux.jacobian(() -> model(x),Flux.params(model))

In [None]:
function fast_K(model, x, show_progress = false)
    N = check_dim(x)
    m = length(model(x[:,1]))  # Number of functions in the model output
    K = zeros(N*m, N*m)

    D = Flux.jacobian(() -> model(x),Flux.params(model))
    D = hcat([grad for grad in D]...)
    D = D[:,1:end-1] # to skip the last bias
    
    k = size(D, 2) # number total parametes

    if show_progress
        p = Progress(N, 1, "Computing kernel:", 50)
    end

    for i = 1:N
        for j = 1:N
            block_i = (i - 1) * m + 1
            block_j = (j - 1) * m + 1
            block = D[block_i:block_i+m-1, :] * transpose(D[block_j:block_j+m-1, :])
            K[block_i:block_i+m-1, block_j:block_j+m-1] = block
        end
        if show_progress
            next!(p)  # Increment progress meter
        end
    end

    if show_progress
        finish!(p)  # Finish progress meter
    end

    return K
end

In [148]:

function fast_multidim_K(model, x, show_progress=false)
    N = check_dim(x)                    # Number of datapoints
    m = length(model(x[:,1]))           # Number of functions in the model output

    Θ = zeros(N*(m*m), N*(m*m))         # Kernel is depecicted in research papers and Wikipedia

    D = Flux.jacobian(() -> model(x),Flux.params(model))
    D = hcat([grad for grad in D]...)
    D = D[:,1:end-1]                    # To skip the last bias

    ∂(f,x) = D[(f-m)+(x*m),:]           # Used in nested for readabiity

    if show_progress
        progress_Θ = Progress(m, 1, "Computing Θ:", 50)
    end

    for k = 1:m
        for l = 1:m
            mini_kernel = zeros(N,N)
            for i = 1:N
                for j = 1:N
                    mini_kernel[i,j] = dot(∂(k,i),∂(l,j))
                end
            end
            # Add mini_kernel to the corresponding portion of Θ
            Θ[(k-1)*N+1:k*N, (l-1)*N+1:l*N] .= mini_kernel
        end
        if show_progress
            next!(progress_Θ)           # Increment progress meter Θ
        end
    end

    return Θ
end

fast_multidim_K (generic function with 2 methods)

In [145]:
BIG_K = fast_multidim_K(model,hcat(xVec...),true)

[32mComputing Θ: 100%|██████████████████████████████████████████████████| Time: 0:02:30[39m[K


8000×8000 Matrix{Float64}:
 3.62445  3.42374  3.32293  3.21078  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 3.42374  3.34587  3.25245  3.14608     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 3.32293  3.25245  3.17468  3.07815     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 3.21078  3.14608  3.07815  3.08306     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 3.07181  3.01924  2.96072  2.97493     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 2.97885  2.93286  2.8818   2.90205  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 2.86703  2.83036  2.78734  2.81355     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 2.62568  2.59626  2.55936  2.58943     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 2.56018  2.53549  2.5041   2.5369      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 2.50193  2.48092  2.45469  2.49069     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                                   ⋱            ⋮                   
 0.0      0.0      0.0      0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0      0.0      0.0      0.0         0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0      0.0 

In [146]:
eigen(BIG_K).values

8000-element Vector{Float64}:
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  ⋮
 44.35405435550982
 45.649067733304776
 47.04806805132091
 49.45869665928705
 50.610121204609186
 54.72200771216017
 56.55862287897324
 58.68728418726743
 66.4106631922362

In [None]:
fast_K(model, hcat(xVec...), true)