In [175]:
using MLDatasets, Flux, LinearAlgebra, ProgressMeter, Zygote, ReverseDiff
include("DenseNTK.jl"); include("FastNTK.jl")

function load_MNIST()
    """
    Loading the MNIST dataset.
    10 classes of digits from 0 to 9,
    each with 28x28 pixel dimensions.
    X: Grayscale vector, Y: Correct label.
    """

    X_training, Y_training = MNIST(split = :train)[:]
    X_testing, Y_testing = MNIST(split = :test)[:]
    X_training = Flux.flatten(X_training)
    X_testing = Flux.flatten(X_testing)
    Y_training = Flux.onehotbatch(Y_training, 0:9)
    Y_testing = Flux.onehotbatch(Y_testing, 0:9)
    return X_training, Y_training, X_testing, Y_testing
end

### MODELS
DenseNTKmodel = Chain(
  DenseNTK(28 * 28, 16, sigmoid),        # 784 x 16 + 16 = 12560 parameters
  DenseNTK(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  DenseNTK(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  DenseNTK(16, 10, sigmoid)              #  16 x 10 + 10 =   170 parameters
)  

model = Chain(
  Dense(28 * 28, 16, sigmoid),        # 784 x 16 + 16 = 12560 parameters
  Dense(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  Dense(16, 16, sigmoid),             #  16 x 16 + 16 =   272 parameters
  Dense(16, 10, sigmoid)              #  16 x 10 + 10 =   170 parameters
)   

### DATA 
N=10
x = load_MNIST()[1][:,1:N];

Using only Zygote's jacobian

In [135]:
zygote = Zygote.jacobian(() -> model(x), Flux.params(model))
zygote = hcat([(grad) for grad in zygote]...);

display(zygote)

println("Jacobian computed with $N datapoints and $(length(Flux.destructure(DenseNTKmodel)[1])) parameters.")

100×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.246681  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.244653  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.223027
 ⋮                        ⋮                   ⋱                      
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0

Jacobian computed with 10 datapoints and 13274 parameters.


Zygote's jacobian splitted

In [18]:
function split(model, x, show_progress=false)
    N = check_dim(x)
    m = length(model(x[:,1]))
    k = length(Flux.destructure(model)[1])
    lastbias = length(Flux.params(model)[length(Flux.params(model))])
    
    Df = zeros(N*m, k)

    if show_progress
        prog = Progress(N, 1)
    end

    for i = 1:N
        D = Zygote.jacobian(() -> model(x[:,i]), Flux.params(model))
        D = hcat([(grad) for grad in D]...) # is m*k matrix

        Df[(i-1)*m+1:i*m, :] .= D  

        if show_progress
            next!(prog)  # Update progress meter
        end
    end
    return Df
end

split_zygote = split(model,x)
println("Jacobian computed with $N datapoints and $(length(Flux.destructure(DenseNTKmodel)[1])) parameters.")

Jacobian computed with 1 datapoints and 13274 parameters.


ReverseDiff: Jacobian tape

In [178]:
params, restruct = Flux.destructure(model)

function m(x,p::Vector)
    mod = restruct(p)
    y = mod(x)
end

D= Zygote.jacobian(params) do params
    m(x,params)
end
fun = hcat(D...)

100×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227642  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.212916  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249921
 ⋮                        ⋮                   ⋱                      
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0

In [80]:
a = round.(fun, digits=10) == round.(split_zygote, digits=10)
b = round.(fun, digits=16) == round.(zygote, digits=16)
c = fun == zygote

display(a)
display(b)
display(c)


false

true

true

Þessar niðurstöður sýna okkur að split aðferðin skilar ónákvæmnari gildum in ult_kernel, sem búast mátti við.
<br>
Athugum nú hvort diffrunin sé möguleg með ReverseDiff

In [179]:
## JACOBIAN ReverseDiff
J = ReverseDiff.jacobian(m, (x[:,10], params))

J[2] # Þetta er jacobian-inn

10×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227722  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.213368  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249893

In [89]:
J[2] ≈ fun

true

In [180]:
tape = ReverseDiff.JacobianTape(m, (x[:,1],params))
comp_tape = ReverseDiff.compile(tape)

Jacobian_result = ReverseDiff.jacobian!(comp_tape, (x[:,10], params))[2]

10×13274 Matrix{Float32}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.227729  0.0       0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.213015  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0       0.0       0.249914

In [182]:
# Zygote
@time begin
    zygote = Zygote.jacobian(() -> model(x[:,10]), Flux.params(model))
    zygote = hcat([(grad) for grad in zygote]...)
end

# ReverseDiff - No tape
@time ReverseDiff.jacobian(m, (x[:,10], params))[2]

# ReverseDiff - Compiled tape
@time ReverseDiff.jacobian!(comp_tape, (x[:,10], params))[2];

  0.686210 seconds (2.19 M allocations: 114.001 MiB, 4.59% gc time, 99.11% compilation time)
  0.007520 seconds (21.40 k allocations: 2.957 MiB)
  0.004461 seconds (50 allocations: 1.545 MiB)
