In [None]:
using DataFrames
using Distributions: cquantile, TDist
using CSV
using HDF5
using Plots
using Printf
using StatsPlots
using Statistics

include("ActiveDomainAdaptation.jl")
using .ActiveDomainAdaptation
include("DataSets.jl")
using .DataSets

In [None]:
# Fetch MNIST from the local data directory and split it into train/test arrays.
(X_mnist_train, y_mnist_train, X_mnist_test, y_mnist_test) =
    prepare_mnist(get_mnist("data/mnist"))
# Show the array shapes of the train and test inputs.
map(size, (X_mnist_train, X_mnist_test))

In [None]:
"""
    process_csv(df; α=0.01)

Summarise the `:accuracy` column of `df` for every value of `:round`.

Rows are grouped by `:round`. For each group the mean accuracy and the
half-width of a two-sided `(1 - α)` Student-t confidence interval for that mean
are computed. The sample size is taken from the number of rows in each group
rather than assumed, so the interval stays correct when the number of runs per
round changes.

# Keywords
- `α`: significance level; the default `0.01` gives a 99 % confidence interval.

# Returns
A tuple `(acc_mean, ci)` of vectors with one entry per round, in the group
order produced by `groupby`. `ci` is `NaN` for rounds with fewer than two rows,
where no standard deviation (and hence no interval) is defined.
"""
function process_csv(df; α=0.01)
    # One pass over the groups: mean, standard deviation and sample size.
    stats = combine(
        groupby(df, :round),
        :accuracy => mean => :acc_mean,
        :accuracy => std => :acc_std,
        nrow => :n,
    )
    # Half-width t_{α/2, n-1} * s / sqrt(n); TDist needs n - 1 > 0 degrees of freedom.
    halfwidth(s, n) = n > 1 ? cquantile(TDist(n - 1), α / 2) * (s / sqrt(n)) : NaN
    ci = halfwidth.(stats.acc_std, stats.n)
    return stats.acc_mean, ci
end

# Load the recorded accuracies of each sampling strategy.
df_random = CSV.File("data/random.csv") |> DataFrame
df_entropy = CSV.File("data/entropy.csv") |> DataFrame
df_mcdropout = CSV.File("data/mcdropout.csv") |> DataFrame

# Per-round mean accuracy and confidence-interval half-width for each strategy.
(mean_random, ci_random) = process_csv(df_random)
(mean_entropy, ci_entropy) = process_csv(df_entropy)
(mean_mcdropout, ci_mcdropout) = process_csv(df_mcdropout)

# Overlay the three strategies on one scatter plot with error bars.
let rounds = 0:30
    plt = scatter(rounds, mean_random; yerror=ci_random, label="Random Sampling",
        legend_position=:bottomright, xlabel="Round", ylabel="Accuracy")
    scatter!(plt, rounds, mean_entropy; yerror=ci_entropy, label="Entropy Sampling")
    scatter!(plt, rounds, mean_mcdropout; yerror=ci_mcdropout, label="MC Dropout Sampling")
end

In [None]:
# Mean accuracy after the last round: MC dropout vs. entropy sampling.
(last(mean_mcdropout), last(mean_entropy))

In [None]:
# Confidence-interval half-width after the last round: MC dropout vs. entropy sampling.
(last(ci_mcdropout), last(ci_entropy))

In [None]:
# Box plots of the accuracy distribution per round, one series per strategy.
boxplot(df_random.round, df_random.accuracy; label="Random Sampling",
    legend_position=:bottomright, xlabel="Round", ylabel="Accuracy")
boxplot!(df_entropy.round, df_entropy.accuracy; label="Entropy Sampling")
boxplot!(df_mcdropout.round, df_mcdropout.accuracy; label="MC Dropout Sampling")

In [None]:
# Exchange file shared with the human labeller; "w" starts from an empty file
# and stores the training inputs and labels under "X" and "y".
file = "data/human_labeller.hdf5"
h5open(file, "w") do datafile
    for (name, data) in (("X", X_mnist_train), ("y", y_mnist_train))
        write(datafile, name, data)
    end
end

"""
    human_labeller(model, X_query, index_query, round)

Hand the current query batch to a human through the HDF5 file named by the
global `file`, then return what the human stored there.

The query indices, the model's output from `probability(model, X_query)` and
the result of `mc_mutual_information(prob_query, model.T)` are written as
`index_query_<round>`, `prob_query_<round>` and `mi_query_<round>`. The function
then blocks on standard input until a line is entered. That line is taken as
the name of a dataset in the same file, and the dataset's contents are returned.

Both file accesses use the `do`-block form of `h5open`, so the handle is
closed even when the read fails (for example on a mistyped dataset name). A
leaked handle could otherwise interfere with reopening the file in a later round.
"""
function human_labeller(model, X_query, index_query, round)
    prob_query = probability(model, X_query)
    mi_query = mc_mutual_information(prob_query, model.T)

    # Publish this round's query data for the human to inspect.
    h5open(file, "r+") do datafile
        write(datafile, @sprintf("index_query_%d", round), index_query)
        write(datafile, @sprintf("prob_query_%d", round), prob_query)
        write(datafile, @sprintf("mi_query_%d", round), mi_query)
    end

    # Wait for the human to enter the name of the dataset holding the answer.
    dataset = readline()
    return h5open(file, "r") do datafile
        read(datafile, dataset)
    end
end

# Run the active-learning simulation with MC-dropout sampling and the
# interactive human labeller, querying 100 samples at a time.
(rounds_human, accuracies_human) = simulate_al(
    mcdropout_sampling,
    human_labeller,
    MCLeNetVariant("mclenet.bson", 20),
    X_mnist_train, y_mnist_train,
    X_mnist_test, y_mnist_test;
    n_query=100,
)