In [None]:
using CategoricalArrays
using HDF5
using Flux
using Plots
using Printf
using Random

In [None]:
include("ADA.jl")
using .ADA

In [None]:
"""
    entr(prob)

Return the Shannon entropy, in bits, obtained by summing `-p * log2(p)` over the
first dimension of `prob` (i.e. one value per column of a matrix).

Any term that evaluates to `NaN` contributes zero. In particular this implements
the convention `0 * log2(0) = 0`, but it also silently zeroes `NaN` entries that
are already present in `prob`.
"""
function entr(prob)
    # per-element p*log2(p), with NaN (e.g. from 0 * -Inf) mapped to zero
    plogp(p) = (t = p * log2(p); isnan(t) ? zero(t) : t)
    colsums = sum(plogp, prob; dims=1)
    return dropdims(-colsums; dims=1)
end

In [None]:
n_query = 50000

In [None]:
# Load the DR16Q superset (pool features plus test features/targets) and keep
# device copies made with Flux's `gpu` for inference.
X, X_test, y_test = get_dr16q("data/dr16q_superset.hdf5")
X_gpu, X_test_gpu = gpu(X), gpu(X_test)
n = size(X, ndims(X))    # samples are indexed along the last dimension of X

model = SZNet("models/sznet.bson")
# evaluation history, one entry per round; the first entry is the initial model
cfrs = [cfr(y_test, predict(model, X_test_gpu))]

In [None]:
# Start with every sample in the unlabelled pool and an empty training set.
idx_pool = collect(1:n)
idx_train = Int[]
y_train = Float32[]

In [None]:
# Active-learning round counter. Deliberately NOT called `round`: assigning to
# `round` in Main shadows `Base.round`, so any later `round(x)` in this notebook
# would fail. The HDF5 dataset keys below still look like "idx_query_2", etc.
al_round = 2

In [None]:
# process pool
prob = probability(model, X_gpu)
X_pool = view(X, :, idx_pool)
prob_pool = prob[:, idx_pool]    # column j corresponds to sample idx_pool[j]
# query
# TODO random sampling: pos_query = shuffle(eachindex(idx_pool))[1:n_query]
entr_pool = entr(prob_pool)    # entropy sampling
perm = sortperm(entr_pool, rev=true)    # entropy sampling: most uncertain first
# `perm` holds positions *within the pool*, not global sample indices. Cap the
# query size so later rounds do not run past the end of a shrinking pool.
pos_query = perm[1:min(n_query, length(perm))]
idx_query = idx_pool[pos_query]    # global indices into X
# human labeller
# `prob_pool` and `entr_pool` are pool-local, so they must be indexed by pool
# position. Indexing them with `idx_query` is only correct while idx_pool == 1:n.
ŷ_query = Flux.onecold(prob_pool[:, pos_query], ADA.LABELS)
entr_query = entr_pool[pos_query]    # entropy sampling
# write to HDF5 file for eidein
HDF5FILE = "data/human.hdf5"
h5open(HDF5FILE, "cw") do hdf5file
    write(hdf5file, "idx_query_$al_round", idx_query)
    write(hdf5file, "entr_query_$al_round", entr_query)    # entropy sampling
    write(hdf5file, "ypred_query_$al_round", ŷ_query)
end

In [None]:
# Read back the indices the human labeller stored for this round. The do-block
# closes the file even if the dataset is missing.
idx_label = h5open(HDF5FILE, "r") do hdf5file
    read(hdf5file, "idx_label_$al_round")
end
# update training set and pool
idx_train = union(idx_train, idx_label)
idx_pool = setdiff(idx_pool, idx_label)
# read training labels from cache of redshifts
y_train = h5open("data/dr16q_superset_cache.hdf5", "r") do hdf5cache
    read(hdf5cache, "y_cache")[idx_train]
end
# prepare training set
y_train_categorical = cut(y_train, ADA.EDGES, labels=ADA.STR_LABELS)
y_train_onehot = Flux.onehotbatch(y_train_categorical, ADA.STR_LABELS)
X_train = X[:, idx_train]
# learning strategy
finetune!(model, X_train, y_train_onehot)
# evaluate on the test set and append to the per-round history. `push!` keeps one
# entry per round, matching `cfrs = [cfr(...)]`; `vcat` would splice array results.
cfr_round = cfr(y_test, predict(model, X_test_gpu))
push!(cfrs, cfr_round)

## Digits

In [None]:
# Load MNIST from disk and split it into train/test features and labels; the
# cell displays the two feature-array sizes.
X_mnist_train, y_mnist_train, X_mnist_test, y_mnist_test =
    get_mnist("data/mnist") |> prepare_mnist
(size(X_mnist_train), size(X_mnist_test))

In [None]:
# HDF5 file shared with the human labeller, seeded with the MNIST training data.
# NOTE(review): this is the same path as HDF5FILE in the DR16Q section, and mode
# "w" truncates it, so running this cell discards that section's query/label
# datasets. Confirm this is intended.
file = "data/human.hdf5"
h5open(file, "w") do fid
    write(fid, "X", X_mnist_train)
    write(fid, "y", y_mnist_train)
end

"""
    human(model, X_query, index_query, round)

Hand a batch of queried samples to the human labeller and return the indices
they labelled.

Writes `index_query` and `entropy(probability(model, X_query))` to the global
HDF5 `file` as `index_query_<round>` and `entr_query_<round>`. It then blocks on
`readline()` for the name of the dataset in which the labeller stored the
labelled indices, and returns that dataset's contents.
"""
function human(model, X_query, index_query, round)
    entr_query = entropy(probability(model, X_query))
    h5open(file, "r+") do datafile
        write(datafile, @sprintf("index_query_%d", round), index_query)
        write(datafile, @sprintf("entr_query_%d", round), entr_query)
    end
    # Trim stray whitespace so e.g. "index_label_1 " still names the dataset.
    dataset = String(strip(readline()))
    # The do-block closes the file even when `dataset` does not exist. A leaked
    # read-only handle can make the next round's "r+" open of `file` fail.
    index_label = h5open(file, "r") do datafile
        read(datafile, dataset)
    end
    return index_label
end

# Ensemble built from five saved LeNet models (.bson files under models/).
filepaths = [
    "models/$(name).bson" for name in ("lenet", "lenet2", "lenet3", "lenet4", "lenet5")
]
ensemble = DeepEnsembleLeNet(filepaths)
# Run the active-learning simulation with entropy sampling and `human` as the
# labelling step.
rounds_human, accuracies_human = simulate_al(
    entropy_sampling, human, ensemble,
    X_mnist_train, y_mnist_train, X_mnist_test, y_mnist_test;
    n_query=10000,
)