## MONGOOSE Prototype 1
### Goal: Implement smart hash-table update scheduler

Maintain two copies of the parameters: $W$ (the trainable parameters) and $V$ (a lagged copy of $W$).
Maintain a set $S$ of the indices $i$ of the nodes $w_i$ in $W$ for which $||w_i - v_i|| > \epsilon$, where $\epsilon$ is some error threshold parameter (i.e. the lagged copies $v_i$ are stale). So $S$ tracks the indices of the nodes that are stale.

Every time we query the hash table, we also compute $S$. If $|S| > \rho$, we initiate an LSH table update of just those stale nodes. We also include nodes that are close to the decision boundary $\epsilon$ when there are many such nodes, e.g. $S \leftarrow S \cup \{\, i : ||w_i - v_i|| > \epsilon / 2 \,\}$ if $|\{\, i : ||w_i - v_i|| > \epsilon / 2 \,\}| > \rho$, where $\rho$ is some parameter. For example, if 5% of the nodes are near the decision boundary then update those too, because when a lot of nodes are close to the decision boundary they are likely to become stale soon anyway.

So on every LSH query we compute $S$ and then use it to decide whether to initiate a hash-table update of the selected stale nodes. When we update the nodes in $S$, we also update the lagged copy of the parameters $V$ for those nodes so that they are synchronized with the main parameters $W$.

In [None]:
using Revise
using Flux
using Zygote
using MLDatasets
using LSHFunctions
using DataStructures
using Plots;
using Profile;
using StatProfilerHTML;
using BenchmarkTools
using LinearAlgebra;
using SparseArrays
using Random;
include("./Optim.jl");
using .Optim

In [None]:
# Hyperparameters (notebook-level globals read by the functions below)
batch_size = 8
k = 5  # passed to SimHash(k) when the LSH layer is built; earlier runs tried k,L = 7,10
L = 15; # number of hash functions / hash tables built for the LSH layer
bin_size = 20 # capacity of each hash-table bucket (a CircularBuffer)
sample_prop = 0.05 # 5%, proportion of nodes to sample from matrix
update_thresh_eps = 0.05 # minimum difference between lagged vector and trained vector to consider stale, ∈ [0,0.1]
update_thresh_num = 0.1; # fraction of nodes (0.1 = 10%) that must be stale before triggering update

In [None]:
"""
    onehotencode(y, nclasses=10)

Return a length-`nclasses` `Float64` vector that is `1.0` at position `y + 1` and `0.0`
everywhere else, i.e. the one-hot encoding of the zero-based class label `y`.

`y` is dereferenced with `y[]`, so a plain number or a zero-dimensional container both
work. A label outside `0:(nclasses - 1)` raises a `BoundsError`.
"""
function onehotencode(y, nclasses::Integer=10)
    onehot = zeros(Float64, nclasses)
    onehot[y[] + 1] = 1.0
    return onehot
end

In [None]:
# Load the MNIST training split and flatten each image into one row of `x_train`.
train_x, train_y = MNIST.traindata();
# Reverse the axis order so that the last axis (presumably the sample axis) comes first.
train_x = permutedims(train_x, (3, 2, 1))
x_train = reshape(train_x, size(train_x, 1), :);
# One one-hot row per label. `reduce(vcat, ...)` stacks the rows without splatting
# every row into a single enormous `vcat(...)` call.
y_train = reduce(vcat, [transpose(onehotencode(y)) for y in train_y]);
x_train = convert(Array{Float64}, x_train);

In [None]:
# Load the MNIST test split and flatten each image into one row of `x_test`.
test_x, test_y = MNIST.testdata();
# Reverse the axis order so that the last axis (presumably the sample axis) comes first.
test_x = permutedims(test_x, (3, 2, 1))
x_test = reshape(test_x, size(test_x, 1), :);
# One one-hot row per label. `reduce(vcat, ...)` stacks the rows without splatting
# every row into a single enormous `vcat(...)` call.
y_test = reduce(vcat, [transpose(onehotencode(y)) for y in test_y]);
x_test = convert(Array{Float64}, x_test);

In [None]:
#tid = rand(1:60000)
#println(y_train[tid,:])
#Gray.(reshape(x_train[tid,:],(28,28)))

In [None]:
# Mean of every length-`n` sliding window over `vs`; the result has
# `length(vs) - n + 1` entries (none when `n` exceeds `length(vs)`).
function moving_average(vs, n)
    last_start = length(vs) - (n - 1)
    return [sum(view(vs, start:(start + n - 1))) / n for start in 1:last_start]
end

In [None]:
# Enumerate all 2^len tuples of `len` Bool values, with the first position varying
# fastest. The argument `n` is accepted but not used.
function comb(n, len)
    bits = BitArray([0, 1])
    codes = Iterators.product(fill(bits, len)...)
    return vec(collect(codes))
end

In [None]:
# Every possible k-bit hash code as a tuple of Bools; used as the bucket keys of each hash table.
all_hash_codes = comb(1,k);

In [None]:
# State of a dense layer that is not LSH-sampled.
# The outer constructor `FixLayer(in_dim, out_dim)` fills these fields.
mutable struct FixLayer
    theta::Matrix #original weights, (in_dim, out_dim) when built by the outer constructor
    bias::Matrix #original bias row, (1, out_dim) when built by the outer constructor
    theta_lag::Matrix #copy of `theta` taken at construction
    S::Set{Int64} #stores stale node indices
end

# State of a layer whose active nodes (columns of `theta`) are chosen per input by
# locality-sensitive hashing.
mutable struct LSHLayer
    theta::Matrix #original weights; one column per node
    bias::Matrix #original bias row
    theta_lag::Matrix #lagged copy for mongoose
    hash_funs::Vector #Vector{SimHash}; one hash function per hash table
    hash_tables::Vector{Dict{Tuple,CircularBuffer{Integer}}} #per table: hash-code tuple => bucket of node indices
    #sampled::Vector
    S::Set{Int64} #stores stale node indices
end

# Views into a layer's `theta` and `bias`. The layer call methods reassign both
# fields on every forward pass, and the parameters handed to the optimizer are
# collected from these fields.
mutable struct ViewLayer
    thetaV::SubArray # view of the layer's weight matrix
    biasV::SubArray # view of the layer's bias row
end

In [None]:
"""
    LSHLayer(in_dim, out_dim, k=6, L=6, bin_size=bin_size) -> (LSHLayer, ViewLayer)

Build a randomly initialised LSH layer with `out_dim` nodes and return it together
with a `ViewLayer` covering the whole weight matrix and bias row.

`L` hash functions `SimHash(k)` are created, one per hash table. Every table gets one
bucket (a `CircularBuffer` of capacity `bin_size`) per possible `k`-bit hash code, and
each node index is pushed into the bucket selected by the hash of its normalised
weight column.

The bucket keys are enumerated from the `k` argument itself, so the tables always
match the hash functions even when `k` differs from the notebook-level `k`.
"""
function LSHLayer(in_dim::Integer,out_dim::Integer,k=6,L=6,bin_size=bin_size)
    div_n = 10 # divisor applied to the standard-normal initial values
    theta = randn(in_dim,out_dim) / div_n
    bias = randn(1,out_dim) / div_n
    cols = size(theta, 2) #number of columns/nodes
    hash_funs = [SimHash(k) for i in 1:L]
    # Keys must have exactly `k` bits to match the hashes produced by `hash_funs`.
    hash_codes = comb(1, k)
    # Normalise every column once and reuse the result for all L tables.
    unit_cols = [Flux.normalize(theta[:, j]) for j in 1:cols]
    hash_tables = Vector{Dict{Tuple,CircularBuffer{Integer}}}()
    for i in 1:L #create L hash tables
        table = Dict{Tuple,CircularBuffer{Integer}}(
            code => CircularBuffer{Integer}(bin_size) for code in hash_codes
        )
        for j in 1:cols
            code = Tuple(hash_funs[i](unit_cols[j]))
            push!(table[code], j)
        end
        push!(hash_tables, table)
    end
    layer = LSHLayer(theta, bias, copy(theta), hash_funs, hash_tables, Set{Int64}())
    return layer, ViewLayer(view(theta,:,:), view(bias,:,:))
end

# Build a randomly initialised dense layer and return it together with a
# `ViewLayer` spanning the whole weight matrix and bias row.
function FixLayer(in_dim::Integer,out_dim::Integer)
    init_scale = 10
    weights = randn(in_dim, out_dim) / init_scale
    offsets = randn(1, out_dim) / init_scale
    dense = FixLayer(weights, offsets, copy(weights), Set([]))
    full_view = ViewLayer(view(weights, :, :), view(offsets, :, :))
    return dense, full_view
end

In [None]:
"""
    sample_nodes(query, layer) -> Vector{Int64}

Pick the nodes (columns of `layer.theta`) to activate for the input vector `query`.

Hash tables are visited in random order. For each one, `query` is hashed with that
table's hash function and every node index stored in the matching bucket is added to
the sample. Visiting stops once at least `round(sample_prop * cols)` nodes have been
collected; the result can be larger because whole buckets are added. If all tables are
exhausted first, the sample is topped up with uniformly random node indices, so the
function always terminates.
"""
function sample_nodes(query::Vector, layer::LSHLayer)
    #`query` is the input vector for this layer
    num_ht = length(layer.hash_funs) # tables of *this* layer, not of the global `layers`
    cols = size(layer.theta, 2)
    maxN = round(Int64, sample_prop * cols)
    S = Set{Int64}()
    # Each table is visited at most once: revisiting a table can never add new nodes.
    for i in shuffle(1:num_ht)
        length(S) < maxN || break
        q_hash = Tuple(layer.hash_funs[i](query))
        union!(S, layer.hash_tables[i][q_hash])
    end
    # The matching buckets held too few distinct nodes: fill up with random ones.
    target = min(maxN, cols)
    while length(S) < target
        push!(S, rand(1:cols))
    end
    return collect(S)
end

In [None]:
"""
    check_stale(layer) -> Set{Int64}

Find the nodes (columns) of `layer` whose trained weights `layer.theta` have drifted
away from the lagged copy `layer.theta_lag`. Nodes already listed in `layer.S` are
not examined.

The drift of a node is the Euclidean norm of the difference between the two copies,
measured on a random sample of rows (20% of the row count, drawn with replacement).
A node is stale when its drift is at least `update_thresh_eps / 2`. Nodes whose drift
is at least a quarter of that threshold are "nearly stale"; they are added to the
result only when there are more than `1.5 * round(update_thresh_num * cols)` of them.
"""
function check_stale(layer)
    stale_eps = update_thresh_eps / 2.0
    near_eps = 0.25 * stale_eps
    nrows, ncols = size(layer.theta)
    candidates = setdiff(1:ncols, layer.S)
    nsample::Int64 = round(0.2 * nrows)
    sampled_rows = rand(1:nrows, nsample)
    drift = layer.theta[sampled_rows, :] .- layer.theta_lag[sampled_rows, :]
    near_trigger = 1.5 * round(update_thresh_num * ncols)
    stale = Set{Int64}()
    nearly_stale = Set{Int64}()
    for node in candidates
        delta = drift[:, node]
        dist = sqrt(dot(delta, delta))
        if dist >= stale_eps
            push!(stale, node)
        elseif dist >= near_eps #close to the decision boundary
            push!(nearly_stale, node)
        end
    end
    # Many nodes near the boundary: treat them as stale as well.
    return length(nearly_stale) > near_trigger ? union(stale, nearly_stale) : stale
end

In [None]:
#layer1.theta_lag[:,rand(1:1000,90)] .+= randn(784)/100.0;

In [None]:
#check_stale(layer1)
#parallel: 8 ms, array comprehension: 34 ms, for-loop: 35 ms, notS: 35ms
#396 us

In [None]:
#Flux.trainable(a::LSHLayer) = (a.thetaV,a.biasV)
#Flux.trainable(a::FixLayer) = (a.thetaV,a.biasV)

In [None]:
# View of `a` restricted to rows `r` and columns `c`.
# An empty `r` (or `c`) selects every row (or column).
function get_view(a::Matrix,r,c)
    row_sel = isempty(r) ? Colon() : r
    col_sel = isempty(c) ? Colon() : c
    return view(a, row_sel, col_sel)
end

In [None]:
# Forward pass of an LSH layer. The active nodes are sampled from the hash tables
# using the input `X`; `V` is repointed at the matching weight columns and bias
# entries (restricted to `rows`, or all rows when `rows` is empty). The selection is
# hidden from Zygote so that only the affine map below is differentiated.
function (m::LSHLayer)(X::Matrix,rows::Vector,V::ViewLayer)
    Zygote.ignore() do
        active = sort(sample_nodes(vec(X), m))
        V.thetaV = get_view(m.theta, rows, active)
        V.biasV = get_view(m.bias, [], active)
    end
    return X * V.thetaV .+ V.biasV
end
#no activation function applied yet

# Forward pass of a dense layer. `V` is repointed at the weight rows listed in `rows`
# (all rows when `rows` is empty) and at the full bias row; this bookkeeping is hidden
# from Zygote. No activation function is applied.
function (m::FixLayer)(X::Matrix,rows::Vector,V::ViewLayer)
    Zygote.ignore() do
        all_cols = []
        V.thetaV = get_view(m.theta, rows, all_cols)
        V.biasV = get_view(m.bias, [], all_cols)
    end
    return X * V.thetaV .+ V.biasV
end

# Initialize Layers

# Initialize Layers

In [None]:
# Network shape: 784 inputs -> 1000 nodes (LSH-sampled layer) -> 10 outputs (dense layer).
dim1,dim2,dim3 = 784, 1000, 10
layer1, layer1view = LSHLayer(dim1,dim2, k, L) #k,L
layer2, layer2view = FixLayer(dim2,dim3)
# Parallel vectors: layerviews[i] is the ViewLayer that belongs to layers[i].
layers = [layer1, layer2];
layerviews = [layer1view, layer2view];

In [None]:
# Second-dimension (column) index set of the weight view held by `layerview`.
function get_view_indices(layerview)
    weight_view = layerview.thetaV
    return weight_view.indices[2]
end

In [None]:
#get_view_indices(layerviews[1])

In [None]:
# Two-layer forward pass: normalise -> layers[1] -> relu -> normalise -> layers[2] -> softmax.
# `layerviews[i]` is repointed by `layers[i]` during the call. The second layer only
# uses the weight rows that match the columns selected by the first layer.
function model(X::Matrix,layers::Vector,layerviews::Vector)
    #layer 1
    inputs = Flux.normalise(X; dims=ndims(X), ϵ=1e-5)
    all_rows = Vector{Integer}[] # empty => use every input row of the weights
    hidden = NNlib.relu.(layers[1](inputs, all_rows, layerviews[1]))
    #layer 2
    hidden = Flux.normalise(hidden; dims=ndims(hidden), ϵ=1e-5)
    active_rows = get_view_indices(layerviews[1])
    scores = layers[2](hidden, active_rows, layerviews[2])
    return NNlib.softmax(scores, dims=2)
end

In [None]:
# Forward pass through both layers; returns the softmax over the outputs of
# `layers[2]`. `layerviews[1]` and `layerviews[2]` are repointed by the layer calls.
function model(X::Matrix,layers::Vector,layerviews::Vector)
    # layer 1: normalised input -> sampled nodes -> relu
    normed_input = Flux.normalise(X; dims=ndims(X), ϵ=1e-5)
    pre_act = layers[1](normed_input, Vector{Integer}[], layerviews[1])
    act = NNlib.relu.(pre_act)
    # layer 2: rows restricted to the nodes that layer 1 just selected
    normed_act = Flux.normalise(act; dims=ndims(act), ϵ=1e-5)
    selected = get_view_indices(layerviews[1])
    out = layers[2](normed_act, selected, layerviews[2])
    return NNlib.softmax(out, dims=2)
end

In [None]:
# Cross-entropy between the predicted probabilities `ŷ` and the target vector `y`.
function lossfn(ŷ::Vector,y::Vector)
    log_probs = log.(ŷ)
    return -1.0 * LinearAlgebra.dot(log_probs, y)
end

# Cross-entropy of the model's prediction for input `x` against the target `y`.
function lossfn2(x::Matrix,y::Vector,layers::Vector,layerviews::Vector)
    probs = vec(model(x, layers, layerviews))
    return -1.0 * LinearAlgebra.dot(log.(probs), y)
end

In [None]:
"""
    update_htables(S, ms) -> Int

Re-hash the stale nodes in `S` for every layer in `ms` and return how many layers
were updated.

A layer is updated only when `S` holds at least `round(update_thresh_num * cols)`
nodes. In that case every node is popped from `S` (so `S` is emptied in place),
pushed into the bucket selected by each hash function for its current weight column,
and its column of `theta_lag` is synchronised with `theta`. Whether or not an update
happened, `layer.S` is set to `S` itself (the same object, not a copy).
"""
function update_htables(S::Set,ms::Vector)
    u = 0
    for layer in ms
        num_ht = length(layer.hash_funs)
        cols = size(layer.theta, 2)
        t = round(update_thresh_num * cols)
        if length(S) >= t
            u += 1
            while !isempty(S)
                j = pop!(S)
                # Copy the column once; it is the same for every table.
                col = layer.theta[:, j]
                for i in 1:num_ht #iterate tables
                    hash_lj = Tuple(layer.hash_funs[i](col))
                    # NOTE(review): j is not removed from the bucket it was in before;
                    # presumably old entries only leave when the buffer wraps - confirm.
                    push!(layer.hash_tables[i][hash_lj], j)
                end
                layer.theta_lag[:, j] = col
            end
        end
        layer.S = S
    end

    return u
end

In [None]:
# Negative dot product of the element-wise log of the prediction `ŷ` with the
# target `y` (cross-entropy).
function lossfn(ŷ::Vector,y::Vector)
    weighted_logs = LinearAlgebra.dot(log.(ŷ), y)
    return -1.0 * weighted_logs
end

# Run the model on `x` and return the cross-entropy of its output against `y`.
function lossfn2(x::Matrix,y::Vector,layers::Vector,layerviews::Vector)
    prediction = model(x, layers, layerviews)
    log_pred = log.(vec(prediction))
    return -1.0 * LinearAlgebra.dot(log_pred, y)
end

In [None]:
#S = [1,50,90,112,145,240,300,301,500,505,506,511];
#g = Zygote.gradient(w -> lossfn(vec(model(randn(1,784),w,S)),[1.0,0,0,0,0,0,0,0,0,0]),layers)
#g = Zygote.gradient((ypred,ytrue) -> lossfn(ypred,ytrue),(A2,batch_y))
#println(g);

In [None]:
# Flatten the layer views into one parameter list:
# [thetaV of view 1, biasV of view 1, thetaV of view 2, biasV of view 2, ...].
function get_params_view(layerviews::Vector)
    params = []
    foreach(layerviews) do lv
        append!(params, [lv.thetaV, lv.biasV])
    end
    return params
end

In [None]:
# Flatten the gradient structure into a list ordered like the parameter list:
# for every entry of `gs[1]`, dereference it and take `:thetaV` and then `:biasV`.
function convert_grads(gs)
    flat = []
    for layer_grad in gs[1]
        append!(flat, [layer_grad[][:thetaV], layer_grad[][:biasV]])
    end
    return flat
end

In [None]:
# Smoke test: one forward pass on a random 1×784 input (this also repoints the views in `layerviews`).
model(randn(1,784), layers, layerviews)

In [None]:
#lossfn2(randn(1,784),y_train[1,:],layers,layerviews)

In [None]:
# Make a fresh `ViewLayer` for every layer, each spanning the layer's whole
# `theta` and `bias`.
function gen_layerviews(layers)
    return Any[ViewLayer(view(l.theta, :, :), view(l.bias, :, :)) for l in layers]
end

In [None]:
"""
    train(x_train, y_train, layers, epochs=200000) -> Vector{Float64}

Train `layers` for `epochs` single-sample steps and return the loss recorded at every
step.

Each step draws a random row of `x_train`, computes the loss and its gradient with
respect to freshly generated layer views, applies one optimizer step, checks the
first layer for stale nodes and, when enough are stale, refreshes its hash tables.
The number of hash-table updates is printed at the end.

Steps run under `Threads.@threads`. The forward/backward pass runs concurrently, but
everything that mutates shared state (optimizer step, staleness check, hash-table
update, counters, loss history) is serialised with a lock. With more than one thread
the order of the returned losses is therefore not the iteration order.
"""
function train(x_train,y_train,layers,epochs=200000)
    lr = 0.001
    nsamples = size(x_train, 1)
    # Build the optimizer from full views of `layers` rather than from the global
    # `layerviews`, which forward passes may have narrowed to a sampled subset.
    # NOTE(review): assumes Optim.ADAM accepts full-size parameter views - confirm.
    opt = Optim.ADAM(get_params_view(gen_layerviews(layers)), lr)
    lossarr = Float64[]
    updates = Ref(0)
    lk = ReentrantLock()
    Threads.@threads for i in 1:epochs
        rid = rand(1:nsamples)
        x = reshape(float(x_train[rid,:]), 1, :)
        yt = y_train[rid,:]
        layerviews = gen_layerviews(layers)
        # NOTE(review): this pass reads the weights and hash tables without the lock,
        # so it can overlap with another thread's update below.
        l, gs = withgradient((mv) -> lossfn2(x,yt,layers,mv), layerviews)
        grads = convert_grads(gs)
        ps = get_params_view(layerviews)
        lock(lk) do
            step!(opt, ps, grads)
            S = check_stale(layers[1])
            if i % 50 == 0
                print(" $(length(S)) ")
            end
            updates[] += update_htables(S, layers[1:1])
            push!(lossarr, l)
        end
    end
    println("Num hash table updates: $(updates[])")
    return lossarr
end

In [None]:
#opt = SGDM(0.001, 0.9, [])

In [None]:
#=gs_ = gradient((mv) -> lossfn2(reshape(x_train[1,:],(1,784)),y_train[1,:],layers,mv), layerviews)
gs = convert_grads(gs_)
ps = get_params_view(layerviews)
#step3!(opt,ps,gs)=#

In [None]:
# Short timed run of 1000 iterations.
@time train(x_train,y_train,layers,1000)
# Timings noted from earlier runs:
#12.3 s with check_stale, 191 ms without check_stale
# 2.4 s for 1k iters

In [None]:
# Timed training run of 25,000 iterations; the value returned by `train` is kept in `lossarr`.
@time begin
    lossarr = train(x_train,y_train,layers,25000);
end;
# Timings noted from earlier runs:
# ~141 seconds for 250k iter with 2-layers and dim2 = 1000 
# 72 seconds for 25k iter
# 32.9 seconds 25k no updates
# 56 seconds 25k w/ updates 63 ht updates

In [None]:
# Plot the training loss smoothed with a 5000-step moving average.
plot(moving_average(lossarr,5000)) #y top is > 3, x right is 2 x 10^5
#plot(lossarr[1:2250])

In [None]:
# Print the percentage of rows of `xs` for which the model's most probable class
# matches the arg-max of the corresponding row of `ys`.
# Uses the notebook-level `layers` and `layerviews`.
function test_acc(xs,ys)
    ntot = size(ys)[1]
    ncorr = count(1:ntot) do i
        x = reshape(float(xs[i, :]), (1, 784))
        ŷ = vec(model(x, layers, layerviews))
        argmax(ŷ) == argmax(ys[i, :])
    end
    println(100 * (ncorr/ntot))
end

In [None]:
# Print the accuracy (in percent) of the LSH model on the test split.
test_acc(x_test,y_test)

In [None]:
# 63.42 -50k, 1000dim, without hash table updates
# 63.83, with hash table updates
# 81.56 with %50 hash updates
# 80.21 without hash updates
# 80.16 with exp decy hash updates
# 81.3  with %25 updates
# 72.78% without hashing (random columns)
# 80.76 with hashing and sorted columns
# 84.95 div_n = 5
# 90.35 with k,L=5,7 div_n=5
# 90.66 k,L = 5,10
# 90.36     k,L = 5,20 div_n=5
# 93.55 k,L=5,10 with SGDM
# 97.06 with 2 fix layer and SGDM, but 1091.35 seconds run-time
# 91.75  ADAM, lr = 0.0001, 250k iter
# 94.47  ADAM, lr=0.01, 250k iter
# 93.06  ADAM, lr=.1, 250k iter
# 94.94  ADAM, lr=0.0001, HT %50, 250k iter - simhash
# 95.679 ADAM, lr=0.0001, exp decay HT, 250k iter, bin_size = 20
# 96.2 ADAM, lr=0.0001, HT%25, 250k iter, bin_size = 30 (403 seconds)
# 93.39  ADAM, lr = 0.001 exp decay 25k iter, bin_size = 30 ( 72 seconds, 32 seconds if bin size = 10 and acc 91.14)
# 96.76  ADAM, lr = 0.001 exp decay 250k iter, bin_size = 40  added normalization to weights

## Archive

```julia
# Archived variant: collect a (theta, bias) pair of views per layer.
# `idx` is laid out as [rows1, cols1, rows2, cols2, ...].
function get_params_view(layers,idx)
    p = []
    for (n, layer) in enumerate(layers)
        rows, cols = idx[2n - 1], idx[2n]
        append!(p, [get_view(layer.theta, rows, cols), get_view(layer.bias, [], cols)])
    end
    return p
end
```

```julia
# Archived variant: affine map of `X` through the sub-block of `m.theta` selected by
# `rows` and `cols`. An empty selector means "all"; the bias is restricted by `cols` only.
function (m::Layer)(X::Matrix,rows,cols)
    no_cols = isempty(cols)
    no_rows = isempty(rows)
    if no_rows
        # neither given, or cols given alone
        no_cols && return X * m.theta .+ m.bias
        return X * view(m.theta, :, cols) .+ view(m.bias, :, cols)
    end
    # rows given alone
    no_cols && return X * view(m.theta, rows, :) .+ m.bias
    # rows and cols given
    return X * view(m.theta, rows, cols) .+ view(m.bias, :, cols)
end
```

```julia
# Shift `x_` up by the absolute value of its minimum, then divide by the maximum of
# the shifted values. Returns a new array; `x_` is left untouched.
# The result spans [0, 1] when the minimum of `x_` is not positive.
function rescale(x_)
    shifted = x_ .+ abs(minimum(x_))
    return shifted ./ maximum(shifted)
end

# Positions (counted from 1 in iteration order) of the entries of `A` that differ from 0.
function get_nonzero(A)
    return [pos for (pos, val) in enumerate(A) if val != 0]
end

# Build a k-bit winner-take-all style hash. The returned closure maps a vector `x` to a
# BitArray of length `k`: for every bit the non-zero entries of `x` are shuffled, the
# first two are compared, and the bit records which of the two is larger (0 = first).
function WTAH(k::Integer)
    function output(x::Vector, k::Integer)
        window = 2
        bits = BitArray(undef, k)
        nz = get_nonzero(x)
        for i in 1:k
            picked = x[shuffle(nz)][1:window]
            bits[i] = argmax(picked) - 1
        end
        return bits
    end
    return x -> output(x, k)
end
```

# Baseline

In [None]:
# Dense Flux baseline with the same 28^2 -> dim2 -> 10 layer sizes, for comparison with the LSH model.
base_model = Chain(Dense(28^2, dim2, relu), Flux.normalise, Dense(dim2, 10), softmax)

In [None]:
# Smoke test: forward pass of the baseline on the first training sample.
base_model(x_train[1,:])

In [None]:
#collapse
# Train the notebook-level `base_model` for `epochs` single-sample ADAM steps on
# random rows of the global `x_train` / `y_train`. Returns nothing.
function base_train(epochs=100)
    baseopt = Flux.ADAM(0.0001)
    # The parameter collection does not change between steps, so build it once.
    ps = Flux.params(base_model)
    nsamples = size(x_train, 1)
    for i in 1:epochs
        rid = rand(1:nsamples)
        x = float(x_train[rid,:])
        yt = y_train[rid,:]
        gs = gradient(ps) do
            lossfn(base_model(x), yt)
        end
        Flux.update!(baseopt, ps, gs)
    end
end

In [None]:
# Timed baseline training run of 25,000 iterations.
@time begin
    base_train(25000)
end
# Noted from an earlier run:
#25k iter: 111 seconds to 95% acc, max acc is 97.5 (if train to convergence)

In [None]:
# Print the percentage of rows of `xs` for which the baseline model's most probable
# class matches the arg-max of the corresponding row of `ys`.
# Uses the notebook-level `base_model`.
function base_test_acc(xs,ys)
    ntot = size(ys)[1]
    ncorr = count(1:ntot) do i
        ŷ = vec(base_model(float(xs[i, :])))
        argmax(ŷ) == argmax(ys[i, :])
    end
    println(100 * (ncorr/ntot))
end

In [None]:
# Print the accuracy (in percent) of the baseline model on the test split.
base_test_acc(x_test,y_test)