# Navigating Through Loss Space

In [None]:
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using PlotlyJS
using ProgressMeter
using LinearAlgebra

In [2]:
# Load the MNIST training set and split it into 60 minibatches of 1000 samples.
# `cat(...; dims = 4)` stacks the 2-D images along a 4th axis, so each batch input
# has size (height, width, 1, 1000) — presumably Flux's WHCN layout; TODO confirm.
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
         for i in partition(1:60_000, 1000)]
# Evaluation data used throughout the notebook: the first 1000 test images and
# their one-hot labels.
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4)
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9)

# Small CNN. Assuming 28×28 inputs: conv 2×2 -> 27×27×16, maxpool -> 13×13×16,
# conv 2×2 -> 12×12×8, maxpool -> 6×6×8 = 288 features, which is why the dense
# layer is `Dense(288, 10)`. The parametrised layers sit at indices 1 (Conv),
# 3 (Conv) and 6 (Dense); the loss-landscape code below perturbs exactly these.
m = Chain(
  Conv((2,2), 1=>16, relu),
  x -> maxpool(x, (2,2)),
  Conv((2,2), 16=>8, relu),
  x -> maxpool(x, (2,2)),
  x -> reshape(x, :, size(x, 4)),
  Dense(288, 10), softmax)

# Training objective and evaluation metric (fraction of argmax matches).
loss(x, y) = crossentropy(m(x), y)
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
# Training callback: prints the test accuracy, throttled to at most once per 5 s.
evalcb = throttle(() -> @show(accuracy(tX, tY)), 5)

# Optimiser built from the model parameters. Here it evaluates to a function (see the
# cell output), i.e. an older Flux optimiser API — TODO confirm the Flux version.
opt = ADAM(params(m))

#43 (generic function with 1 method)

In [12]:
# Run 10 passes over the 60 training minibatches in `train`. `evalcb` prints the
# test accuracy periodically while `train!` updates the parameters of `m` in place.
@showprogress for i in 1:10
    Flux.train!(loss, train, opt, cb = evalcb)
end

accuracy(tX, tY) = 0.737
accuracy(tX, tY) = 0.75
accuracy(tX, tY) = 0.753
accuracy(tX, tY) = 0.768
accuracy(tX, tY) = 0.763
accuracy(tX, tY) = 0.771
accuracy(tX, tY) = 0.784


Progress:  10%|████                                     |  ETA: 0:07:51

accuracy(tX, tY) = 0.779
accuracy(tX, tY) = 0.784
accuracy(tX, tY) = 0.791
accuracy(tX, tY) = 0.788
accuracy(tX, tY) = 0.785
accuracy(tX, tY) = 0.795
accuracy(tX, tY) = 0.796


Progress:  20%|████████                                 |  ETA: 0:06:58

accuracy(tX, tY) = 0.805
accuracy(tX, tY) = 0.8
accuracy(tX, tY) = 0.808
accuracy(tX, tY) = 0.805
accuracy(tX, tY) = 0.809
accuracy(tX, tY) = 0.813


Progress:  30%|████████████                             |  ETA: 0:06:03

accuracy(tX, tY) = 0.811
accuracy(tX, tY) = 0.82
accuracy(tX, tY) = 0.807
accuracy(tX, tY) = 0.817
accuracy(tX, tY) = 0.813
accuracy(tX, tY) = 0.824
accuracy(tX, tY) = 0.824


Progress:  40%|████████████████                         |  ETA: 0:05:12

accuracy(tX, tY) = 0.821
accuracy(tX, tY) = 0.823
accuracy(tX, tY) = 0.822
accuracy(tX, tY) = 0.829
accuracy(tX, tY) = 0.823
accuracy(tX, tY) = 0.836
accuracy(tX, tY) = 0.828


Progress:  50%|████████████████████                     |  ETA: 0:04:20

accuracy(tX, tY) = 0.833
accuracy(tX, tY) = 0.832
accuracy(tX, tY) = 0.835
accuracy(tX, tY) = 0.84
accuracy(tX, tY) = 0.832
accuracy(tX, tY) = 0.84
accuracy(tX, tY) = 0.838


Progress:  60%|█████████████████████████                |  ETA: 0:03:28

accuracy(tX, tY) = 0.841
accuracy(tX, tY) = 0.838
accuracy(tX, tY) = 0.844
accuracy(tX, tY) = 0.834
accuracy(tX, tY) = 0.845
accuracy(tX, tY) = 0.845


Progress:  70%|█████████████████████████████            |  ETA: 0:02:35

accuracy(tX, tY) = 0.841
accuracy(tX, tY) = 0.84
accuracy(tX, tY) = 0.846
accuracy(tX, tY) = 0.851
accuracy(tX, tY) = 0.845
accuracy(tX, tY) = 0.848
accuracy(tX, tY) = 0.85


Progress:  80%|█████████████████████████████████        |  ETA: 0:01:43

accuracy(tX, tY) = 0.846
accuracy(tX, tY) = 0.846
accuracy(tX, tY) = 0.853
accuracy(tX, tY) = 0.849
accuracy(tX, tY) = 0.851
accuracy(tX, tY) = 0.846
accuracy(tX, tY) = 0.851


Progress:  90%|█████████████████████████████████████    |  ETA: 0:00:52

accuracy(tX, tY) = 0.85
accuracy(tX, tY) = 0.849
accuracy(tX, tY) = 0.854
accuracy(tX, tY) = 0.85
accuracy(tX, tY) = 0.855
accuracy(tX, tY) = 0.858
accuracy(tX, tY) = 0.853


Progress: 100%|█████████████████████████████████████████| Time: 0:08:38


0.852

In [5]:
"""
    perturbweights(m, perturb, t)

Interpolate the weights of layers 1 and 3 (field `weight`) and layer 6 (field `W`)
of `m` towards a rescaled direction taken from `perturb`.

For every coefficient `t[k]` and layer `l` the result holds

    (1 - t[k]) .* W_l .+ t[k] * norm(W_l) / norm(P_l) .* P_l

where `W_l` is the current weight array of layer `l` and `P_l = perturb[l]`.
Rescaling `P_l` to the norm of `W_l` keeps the step size comparable across layers.
So `t[k] = 0` gives the current weights and `t[k] = 1` the rescaled direction.

Returns a vector with one `Dict(layer index => weight array)` per element of `t`,
in the format consumed by `updateweights!`. `m` itself is not modified.
"""
function perturbweights(m, perturb, t)
    # Raw (untracked) weight array of layer `l`: the Conv layers keep it in
    # `weight`, the Dense layer at index 6 in `W`.
    layerweights(l) = l == 6 ? m.layers[l].W.data : m.layers[l].weight.data
    perturbed = [Dict{Int,Any}() for _ in t]
    for l in (1, 3, 6)
        W, P = layerweights(l), perturb[l]
        # Both norms are independent of t, so compute them once per layer.
        nW, nP = norm(W), norm(P)
        for (k, tk) in enumerate(t)
            perturbed[k][l] = (1 - tk) .* W .+ tk * nW / nP .* P
        end
    end
    return perturbed
end

"""
    perturb(m)

Draw a random direction in weight space for the parametrised layers of `m`.

Returns a `Dict` mapping layer index to a `Float32` array shaped like that layer's
weights: layers 1 and 3 are sized from their `weight` field, layer 6 from `W`.
Entries come from `rand`, i.e. uniform on [0, 1), drawn in the order 1, 3, 6.

NOTE(review): uniform [0, 1) draws make every direction entry-wise non-negative;
a zero-mean draw (e.g. `randn`) may be what is intended — confirm.
"""
function perturb(m)
    direction = Dict()
    for (idx, field) in ((1, :weight), (3, :weight), (6, :W))
        dims = size(getproperty(m.layers[idx], field).data)
        direction[idx] = rand(Float32, dims)
    end
    return direction
end

"""
    pw_2d_bilinear(θ0, θ1, θ2, θ3, t)

Build a 2-D grid of weight configurations around the reference weights `θ3`.

All four arguments are `Dict(layer index => weight array)`. `θ3` is the reference
point (intended to be the trained minimum), and its keys select the layers. For
every pair of coefficients `α = t[i]`, `β = t[j]` and every layer `l`:

    ϕ = α * ‖θ3‖/‖θ2‖ .* θ2 .+ (1 - α) .* θ3
    θ = β * ‖θ3‖/‖θ0‖ .* θ0 .+ (1 - β) * ‖θ3‖/‖ϕ‖ .* ϕ

The random directions and the intermediate point `ϕ` are rescaled to the norm of
the reference layer. The normalisation is per layer, not per filter.

`θ1` is not used; it is kept only so existing callers keep working.

Returns `perturbed`, where `perturbed[i][j][l]` is the weight array of layer `l` at
grid point `(t[i], t[j])`. Each `perturbed[i][j]` can be passed to `updateweights!`.
"""
function pw_2d_bilinear(θ0, θ1, θ2, θ3, t)
    α, β = t, t
    perturbed = [[Dict() for _ in 1:length(t)] for _ in 1:length(t)]
    @showprogress for l in keys(θ3)
        ref = θ3[l]
        # Layer norms do not depend on the grid position; compute them once.
        nref, n0, n2 = norm(ref), norm(θ0[l]), norm(θ2[l])
        for i in 1:length(t)
            # ϕ depends only on α[i], so build it once per row rather than per cell.
            ϕ = α[i] * nref / n2 .* θ2[l] .+ (1 - α[i]) .* ref
            nϕ = norm(ϕ)
            for j in 1:length(t)
                perturbed[i][j][l] = β[j] * nref / n0 .* θ0[l] .+
                                     (1 - β[j]) * nref / nϕ .* ϕ
            end
        end
    end
    return perturbed
end

"""
    updateweights!(m, perturbed)

Copy each weight array in `perturbed` (a `Dict(layer index => array)`) into the
corresponding layer of `m`, in place. Layer 6 is written through its `W` field and
every other layer through its `weight` field.

Throws `DimensionMismatch` when an array's size differs from the target layer's
weights. Without the check, `copyto!` would silently overwrite only a prefix when
the source is shorter. Returns `nothing`.
"""
function updateweights!(m, perturbed)
    for (l, w) in perturbed
        dest = l == 6 ? m.layers[l].W.data : m.layers[l].weight.data
        size(dest) == size(w) || throw(DimensionMismatch(
            "layer $l: weights have size $(size(dest)), got $(size(w))"))
        copyto!(dest, w)
    end
    return nothing
end

# Interpolation coefficients for the 1-D scan: -2.0, -1.9, …, 2.0 (41 values).
t = [Float32(t) for t in -2:0.1:2]
# Snapshot of the trained weights of the parametrised layers (1, 3: Conv; 6: Dense).
# `copy` matters: without it these entries alias the live arrays inside `m`, and
# `updateweights!` (which writes in place) would silently overwrite the reference
# point used by the 2-D scan. Re-run this cell after retraining the model.
minimal = Dict(1 => copy(m.layers[1].weight.data), 3 => copy(m.layers[3].weight.data),
               6 => copy(m.layers[6].W.data))

Dict{Int64,Array{Float64,N} where N} with 3 entries:
  3 => [0.119646 0.129721; 0.118327 0.132288]…
  6 => [-0.0533239 0.000418849 … 0.0763203 -0.130841; 0.049489 0.0835781 … 0.00…
  1 => [0.128144 0.128747; 0.0930793 0.134757]…

In [6]:
# 1-D scan: test loss along the path from the current weights (t = 0) towards a
# rescaled random direction (t = 1), for every coefficient in `t`.
per = perturb(m)
perturbed1 = perturbweights(m, per, t);
# `updateweights!` overwrites the parameters of `m` in place, so keep a copy of the
# current weights and put them back after the scan.
trained = Dict(1 => copy(m.layers[1].weight.data), 3 => copy(m.layers[3].weight.data),
               6 => copy(m.layers[6].W.data))
losses = Float64[]

@showprogress for p in perturbed1
    updateweights!(m, p)
    los = crossentropy(m(tX), tY)
    push!(losses, los.data)
end
updateweights!(m, trained)

Progress: 100%|█████████████████████████████████████████| Time: 0:00:17


In [7]:
# 1-D loss profile: test loss on `tX` against the interpolation coefficient `t`
# (t = 0 is the unperturbed model, t = 1 the rescaled random direction).
plot(t, losses)

In [14]:
# 2-D scan: three random directions plus the reference weights θ3.
θ0, θ1, θ2 = perturb(m), perturb(m), perturb(m)
# Private copy of the reference weights: `updateweights!` writes into the arrays of
# `m` in place, and `minimal` may share those arrays.
θ3 = Dict(l => copy(w) for (l, w) in minimal)
t = [Float32(t) for t in -1:0.1:1]
perturbed = pw_2d_bilinear(θ0, θ1, θ2, θ3, t)

# Evaluate the test loss at every grid point (t[i], t[j]).
tl = length(t)
losses = Array{Float32,2}(undef, tl, tl)
@showprogress for i in 1:tl
    for j in 1:tl
        updateweights!(m, perturbed[i][j])
        los = crossentropy(m(tX), tY)
        losses[i, j] = los.data
    end
end
# Restore the reference weights so `m` is not left at the last grid point.
updateweights!(m, θ3)

Progress: 100%|█████████████████████████████████████████| Time: 0:03:03


In [15]:
# Loss surface on a log10 scale. Both horizontal axes are labelled with the actual
# interpolation coefficients `t` rather than grid indices 0…length(t)-1.
trace = surface(x=t, y=t, z=log10.(losses))
plot(trace, Layout(autosize=false))

Reference: Li et al., "Visualizing the Loss Landscape of Neural Nets", https://arxiv.org/abs/1712.09913

![deeploss.PNG](attachment:deeploss.PNG)

![skiploss.PNG](attachment:skiploss.PNG)