In [103]:
using MLDatasets, Flux, LinearAlgebra, ProgressMeter, Zygote, ReverseDiff
include("DenseNTK.jl"); include("FastNTK.jl")

function load_MNIST()
    """
    Loading the MNIST dataset.
    10 classes of digits from 0 to 9,
    each with 28x28 pixel dimensions.
    X: Grayscale vector, Y: Correct label.
    """

    X_training, Y_training = MNIST(split = :train)[:]
    X_testing, Y_testing = MNIST(split = :test)[:]
    X_training = Flux.flatten(X_training)
    X_testing = Flux.flatten(X_testing)
    Y_training = Flux.onehotbatch(Y_training, 0:9)
    Y_testing = Flux.onehotbatch(Y_testing, 0:9)
    return X_training, Y_training, X_testing, Y_testing
end

### MODELS
DenseNTKmodel = Chain(
  DenseNTK(28 * 28, 16, sigmoid),        # 784 x 16 + 16 = 12560 parameters
  DenseNTK(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  DenseNTK(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  DenseNTK(16, 10, sigmoid)              #  16 x 10 + 10 =   170 parameters
)  

model = Chain(
  Dense(28 * 28, 16, sigmoid),        # 784 x 16 + 16 = 12560 parameters
  Dense(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  Dense(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  Dense(16, 10, sigmoid)              #  16 x 10 + 10 =   170 parameters
)   

### DATA 
N=3000
x = load_MNIST()[1][:,1:N];

Using only Zygote's jacobian

In [135]:
zygote = Zygote.jacobian(() -> model(x), Flux.params(model))
zygote = hcat([(grad) for grad in zygote]...);

display(zygote)

println("Jacobian computed with $N datapoints and $(length(Flux.destructure(DenseNTKmodel)[1])) parameters.")

100×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.246681  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.244653  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.223027
 ⋮                        ⋮                   ⋱                      
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0

Jacobian computed with 10 datapoints and 13274 parameters.


Zygote's jacobian splitted

In [18]:
function split(model, x, show_progress=false)
    N = check_dim(x)
    m = length(model(x[:,1]))
    k = length(Flux.destructure(model)[1])
    lastbias = length(Flux.params(model)[length(Flux.params(model))])
    
    Df = zeros(N*m, k)

    if show_progress
        prog = Progress(N, 1)
    end

    for i = 1:N
        D = Zygote.jacobian(() -> model(x[:,i]), Flux.params(model))
        D = hcat([(grad) for grad in D]...) # is m*k matrix

        Df[(i-1)*m+1:i*m, :] .= D  

        if show_progress
            next!(prog)  # Update progress meter
        end
    end
    return Df
end

split_zygote = split(model,x)
println("Jacobian computed with $N datapoints and $(length(Flux.destructure(DenseNTKmodel)[1])) parameters.")

Jacobian computed with 1 datapoints and 13274 parameters.


ReverseDiff: Jacobian tape

In [178]:
params, restruct = Flux.destructure(model)

function m(x,p::Vector)
    mod = restruct(p)
    y = mod(x)
end

D= Zygote.jacobian(params) do params
    m(x,params)
end
fun = hcat(D...)

100×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227642  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.212916  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249921
 ⋮                        ⋮                   ⋱                      
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0

In [80]:
a = round.(fun, digits=10) == round.(split_zygote, digits=10)
b = round.(fun, digits=16) == round.(zygote, digits=16)
c = fun == zygote

display(a)
display(b)
display(c)


false

true

true

Þessar niðurstöður sýna okkur að split aðferðin skilar ónákvæmnari gildum in ult_kernel, sem búast mátti við.
<br>
Athugum nú hvort diffrunin sé möguleg með ReverseDiff

In [179]:
## JACOBIAN ReverseDiff
J = ReverseDiff.jacobian(m, (x[:,10], params))

J[2] # Þetta er jacobian-inn

10×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227722  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.213368  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249893

In [89]:
J[2] ≈ fun

true

In [180]:
tape = ReverseDiff.JacobianTape(m, (x[:,1],params))
comp_tape = ReverseDiff.compile(tape)

Jacobian_result = ReverseDiff.jacobian!(comp_tape, (x[:,10], params))[2]

10×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227729  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.213015  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249914

In [182]:
# Zygote
@time begin
    zygote = Zygote.jacobian(() -> model(x[:,10]), Flux.params(model))
    zygote = hcat([(grad) for grad in zygote]...)
end

# ReverseDiff - No tape
@time ReverseDiff.jacobian(m, (x[:,10], params))[2]

# ReverseDiff - Compiled tape
@time ReverseDiff.jacobian!(comp_tape, (x[:,10], params))[2];

  0.686210 seconds (2.19 M allocations: 114.001 MiB, 4.59% gc time, 99.11% compilation time)
  0.007520 seconds (21.40 k allocations: 2.957 MiB)
  0.004461 seconds (50 allocations: 1.545 MiB)


In [109]:
### TRACKER
using Tracker

function remove_last_bias(model, Jacobian)
    """Removes last bias of model in jacobian, because of 'frozen parameter' """
    lastbias = length(Flux.params(model)[length(Flux.params(model))])
    Jacobian = Jacobian[:, 1:end-lastbias]
    return Jacobian
end

params, restruct = Flux.destructure(DenseNTKmodel)

g = (p) -> begin
    mod = restruct(p)
    y = mod(x[:,1])
end

l = length(g(params))
D = zeros(N*l,length(params))

@time begin
@showprogress for i = 1:size(x)[2]
    h = (p) -> begin
        mod = restruct(p)
        y = mod(x[:,i])
    end
    d = Tracker.data(Tracker.jacobian(h, params))
    D[(i-1)*l+1:i*l, :] .= d
end
end # time ends

D = remove_last_bias(DenseNTKmodel, D)

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:30[39m[K


 30.602500 seconds (24.61 M allocations: 37.503 GiB, 9.70% gc time, 0.10% compilation time)


30000×13264 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0519299  0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0415952  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0537601
 ⋮                        ⋮              ⋱                        
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0        0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0        0.0

In [115]:
K = zeros(N*l, N*l)

@showprogress for i = 1:size(D)[1]
    for j = 1:size(D')[2]
        K[i,j] = dot(D[i,:], D'[:,j])
    end
end

In [110]:
D*D'

30000×30000 Matrix{Float64}:
  0.0195356    -0.000178003   0.000173649  …  -7.25783e-5   -0.000246979
 -0.000178003   0.0215972    -0.00033316      -0.000339688  -0.000607827
  0.000173649  -0.00033316    0.0190076       -0.000214725  -0.000168916
  0.000943974  -0.000761443  -0.00010866      -0.000158166   0.00088693
  0.000142383   0.000395563   1.55923e-5      -0.000264434  -0.000334077
 -0.000198168   0.000334908  -0.000474511  …   0.000275503   2.59237e-5
 -0.000108574  -0.000131702  -0.000175174      0.000129531   7.40894e-5
 -0.000670298  -0.000255474   0.00129265      -0.000626777   0.000156477
 -7.20978e-5   -0.000339085  -0.000216417      0.0128442    -9.66898e-5
 -0.00024865   -0.000613274  -0.000168246     -9.65809e-5    0.0218568
  ⋮                                        ⋱                
 -0.000176645   0.0215842    -0.000332527     -0.000340413  -0.000602777
  0.00017423   -0.000332534   0.0189959       -0.000213173  -0.000169512
  0.000945519  -0.000758957  -0.00010928