In [1]:
using Pkg
Pkg.activate(".")

using Plots, HypertextLiteral, Random
include("../src/load.jl") # load datasets
include("../src/surrogate.jl")
include("../src/run.jl") # run tests

[32m[1m  Activating[22m[39m project at `~/Research/SurrogateDistanceModels/notebooks`


test_searchgraph (generic function with 2 methods)

In [2]:
"""
    MaxHashSurrogate(samplesize, npools, dim, kscale; rng=Random.default_rng())

Surrogate model made of `npools` random coordinate pools. Each pool holds
`samplesize` distinct coordinate indices drawn from `1:dim` (the prefix of a
fresh random permutation of `1:dim`).

# Arguments
- `samplesize`: number of coordinates per pool; must satisfy
  `1 <= samplesize <= min(255, dim)`. The upper bound of 255 exists because the
  encoders in this file store a position inside the pool in a `UInt8` cell.
- `npools`: number of pools, i.e. the number of columns of `pool`.
- `dim`: dimension of the vectors that will be encoded.
- `kscale`: stored as-is in the `kscale` field; it is not used by the constructor.

# Keywords
- `rng`: random number generator used to draw the permutations. The default is
  the task-local default generator, so calls without `rng` behave as before.

# Throws
- `ArgumentError` when `samplesize` is outside `1:255` or larger than `dim`.
"""
struct MaxHashSurrogate <: AbstractSurrogate
    # Scale factor carried along with the surrogate.
    kscale::Int
    # samplesize × npools matrix; column `i` lists the coordinates of pool `i`.
    pool::Matrix{Int32}

    function MaxHashSurrogate(samplesize::Integer, npools::Integer, dim::Integer, kscale::Integer;
            rng::AbstractRNG=Random.default_rng())
        1 <= samplesize < 256 || throw(ArgumentError(
            "samplesize must be in 1:255 (pool positions are stored as UInt8), got $samplesize"))
        # Without this check the failure would surface as a BoundsError from the view below.
        samplesize <= dim || throw(ArgumentError(
            "samplesize ($samplesize) cannot exceed dim ($dim)"))

        pool = Matrix{Int32}(undef, samplesize, npools)
        perm = Vector{Int32}(1:dim)

        for i in 1:npools
            # Reshuffle in place and keep the first `samplesize` entries: a sample
            # of distinct coordinates without replacement.
            randperm!(rng, perm)
            pool[:, i] .= view(perm, 1:samplesize)
        end

        new(kscale, pool)
    end
end

"""
    samplesize(M::MaxHashSurrogate)

Return the number of rows of `M.pool`, i.e. how many coordinates each pool holds.
"""
function samplesize(M::MaxHashSurrogate)
    nrows, _ = size(M.pool)
    return nrows
end
"""
    npools(M::MaxHashSurrogate)

Return the number of columns of `M.pool`, i.e. how many pools the surrogate has.
"""
function npools(M::MaxHashSurrogate)
    _, ncols = size(M.pool)
    return ncols
end
"""
    kscale(M::MaxHashSurrogate)

Return the `kscale` field stored in `M`.
"""
function kscale(M::MaxHashSurrogate)
    return getfield(M, :kscale)
end

"""
    encode(M::MaxHashSurrogate, vout, v)

Fill `vout` in place and return it. For every index `i` of `vout`, look at the
coordinates listed in column `i` of `M.pool` and store the position inside that
column (not the coordinate itself, and not the value) whose entry of `v` is
largest.
"""
function encode(M::MaxHashSurrogate, vout, v)
    for slot in eachindex(vout)
        column = view(M.pool, :, slot)
        # findmax(f, domain) yields (maximum of f, index into domain); only the index is kept.
        _, winner = findmax(coord -> v[coord], column)
        vout[slot] = winner
    end

    return vout
end

"""
    encode(M::MaxHashSurrogate, db_::AbstractDatabase)

Encode every object of `db_` with `M` and return the codes wrapped in a
`MatrixDatabase`. The underlying matrix is `UInt8` with `npools(M)` rows and one
column per object of `db_`.
"""
function encode(M::MaxHashSurrogate, db_::AbstractDatabase)
    codes = Matrix{UInt8}(undef, npools(M), length(db_))

    for objid in eachindex(db_)
        # Write the code of object `objid` straight into its column, no temporary.
        column = @view codes[:, objid]
        encode(M, column, db_[objid])
    end

    return MatrixDatabase(codes)
end

"""
    encode(M::MaxHashSurrogate, db_::AbstractDatabase, queries_::AbstractDatabase, params)

Encode both `db_` and `queries_` with `M`, record the surrogate's settings in
`params` (mutated in place under the keys `"surrogate"`, `"samplesize"`,
`"npools"` and `"kscale"`), and return the named tuple
`(; db, queries, params, dist)` where `dist` is a `StringHammingDistance()`.
"""
function encode(M::MaxHashSurrogate, db_::AbstractDatabase, queries_::AbstractDatabase, params)
    dist = StringHammingDistance()
    db = encode(M, db_)
    queries = encode(M, queries_)

    # Same keys and insertion order as before, written as a single table.
    settings = (
        "surrogate" => "MaxHash",
        "samplesize" => samplesize(M),
        "npools" => npools(M),
        "kscale" => kscale(M),
    )
    for (key, value) in settings
        params[key] = value
    end

    return (db=db, queries=queries, params=params, dist=dist)
end



encode (generic function with 9 methods)

In [3]:
"""
    run_experiment(D, k; kscalelist, npairslist, npoolslist, ssizelist)

Run the baseline and surrogate evaluations on dataset `D` for `k` neighbors.

`D` must expose `db`, `queries`, `dist` and `params`; `D.params["k"]` is
overwritten with `k`. The function first calls `test_exhaustive` and
`test_searchgraph` on the original data, then builds, for every value in
`kscalelist`, one `BinaryHammingSurrogate` per entry of `npairslist` and one
`MaxHashSurrogate` per `(ssize, npools)` combination, and finally evaluates each
surrogate with `k` multiplied by the surrogate's `kscale` field.

Returns `nothing`.
"""
function run_experiment(D, k;
        kscalelist=[1, 8],
        npairslist=[256, 512, 1024, 2048],
        npoolslist=[32, 64, 128, 256],
        ssizelist=[4, 8, 16]
    )
    D.params["k"] = k
    # Result of the exhaustive run on the original data; handed to every later test call.
    gold = test_exhaustive(nothing, D.db, D.queries, D.dist, copy(D.params), k)
    test_searchgraph(gold, D.db, D.queries, D.dist, copy(D.params), k)

    dim = length(D.db[1])
    models = []
    # All surrogates are built before any is evaluated, in this exact order.
    for scale in kscalelist
        append!(models, [BinaryHammingSurrogate(scale, npairs, dim) for npairs in npairslist])

        for ssize in ssizelist
            for poolcount in npoolslist
                push!(models, MaxHashSurrogate(ssize, poolcount, dim, scale))
            end
        end
    end

    for model in models
        enc = encode(model, D.db, D.queries, copy(D.params))
        scaledk = k * model.kscale
        test_exhaustive(gold, enc.db, enc.queries, enc.dist, copy(enc.params), scaledk)
        test_searchgraph(gold, enc.db, enc.queries, enc.dist, copy(enc.params), scaledk, 0)
    end

    return nothing
end

run_experiment (generic function with 1 method)

In [None]:
# Number of neighbors requested in every experiment below.
k = 32

# Each dataset is loaded inside its own `let`, so `D` does not outlive its experiment.

# GloVe, 400k
let D = load_glove_400k()
    @show size(D.db.matrix), D.dist
    run_experiment(D, k)
end

# WIT, 300k
let D = load_wit_300k()
    @show size(D.db.matrix), D.dist
    run_experiment(D, k)
end

# GloVe, 1M
let D = load_glove_1m()
    @show size(D.db.matrix), D.dist
    run_experiment(D, k)
end

# BigANN, 1M
let D = load_bigann_1m()
    @show size(D.db.matrix), D.dist
    run_experiment(D, k)
end
