In [None]:
using MLToolkit, Logging, ProgressMeter, Statistics, PyCall, ArgParse, Printf, Dates, LaTeXStrings
import Random
const tbX = pyimport("tensorboardX")

# Call Knet's `gpu` with `false` before anything else; an explicit device may still be
# selected further down from `--gpu_id`.
# NOTE(review): presumably this turns GPU usage off by default — confirm against Knet's docs.
MLToolkit.Knet.gpu(false)
# Short alias for the batched normal distribution type parameterised by `FT`.
# NOTE(review): `BatchNormal` and `FT` are not defined in this file (presumably exported by
# MLToolkit) — confirm.
const Normal = BatchNormal{FT}

# Auto-reload packages imported afterwards if in Jupyter
@jupyter using Revise

using RRVAE

;

In [None]:
# Command-line interface of the experiment.
# A default of -1 marks an option as "not specified". For the options that have a
# dataset preset (see the preset table below the parsing step) the -1 is replaced by the
# preset value afterwards.
settings = ArgParseSettings()

@add_arg_table settings begin
    "--alpha"
        # Must be a concrete numeric type: ArgParse converts numeric options with
        # `parse(arg_type, str)`, and `parse` has no method for the abstract type `Real`,
        # so `--alpha 4.0` could not be parsed before. `-1.0 == -1` still holds, so the
        # "not specified" check keeps working.
        arg_type = Float64
        default = -1.0
    "--k_max"
        arg_type = Int
        default = -1
    "--dataset"
        arg_type = String
        required = true
        range_tester = (d -> d in ("synth", "synth_large", "mnist", "fmnist"))
    "--tr_sz"
        arg_type = Int
        default = -1
    "--te_sz"
        arg_type = Int
        default = -1
    "--h_dim"
        arg_type = Int
        default = -1
    "--batch_size"
        arg_type = Int
        required = true
    "--n_epochs"
        arg_type = Int
        default = -1
    "--bnp"
        arg_type = String
        required = true
        range_tester = (b -> b in ("ibp", "crp"))
    "--inf"
        arg_type = String
        required = true
        range_tester = (i -> i in ("mf", "s", "rrs"))
    "--obs"
        arg_type = String
        required = true
        range_tester = (o -> o in ("gauss", "ber"))
    "--is_deep"
        arg_type = Bool
        required = true
    "--lr"
        arg_type = Float64
        default = 0.001
    "--beta1"
        arg_type = Float64
        default = 0.99
    "--beta2"
        arg_type = Float64
        default = 0.999
    "--gpu_id"
        arg_type = Int
        default = -1
    "--eval_every"
        arg_type = Int
        default = -1
    "--check_every"
        arg_type = Int
        default = -1
    "--continue"
        action = :store_true
    "--store"
        action = :store_true
end

# Arguments used when running inside Jupyter, where there is no real command line.
# They are written over several lines for readability and flattened to one line below.
args_str_with_line_breaks = """
--dataset synth --batch_size 200 --bnp ibp --inf rrs --obs gauss --is_deep true --gpu_id 3
--store"""

# Turn every line break into a single space.
args_str = join(split(args_str_with_line_breaks, "\n"), " ")

# In Jupyter parse the string above, otherwise the process arguments.
# Keys of the resulting dictionary are symbols (`as_symbols=true`).
args_source = isjupyter() ? args_str : ARGS
args = parse_args(args_source, settings; as_symbols=true)

# Per-dataset presets. They are used for every option the user left unspecified,
# i.e. that is absent from `args` or still equal to -1.
# The epoch budget depends on whether the inference method is "rrs".
is_rrs = args[:inf] == "rrs"
PRESETARGS = Dict(
    "synth" => Dict(
        :alpha => 4.0, :k_max => 9, :h_dim => 50, :tr_sz => 2_400, :te_sz => 400,
        :n_epochs => (is_rrs ? 3_000 : 1_500), :eval_every => 25, :check_every => 50,
    ),
    "synth_large" => Dict(
        :alpha => 8.0, :k_max => 25, :h_dim => 150, :tr_sz => 12_000, :te_sz => 2_000,
        :n_epochs => (is_rrs ? 600 : 300), :eval_every => 5, :check_every => 10,
    ),
    "mnist" => Dict(
        :alpha => 10.0, :k_max => 50, :h_dim => 500, :tr_sz => 60_000, :te_sz => 10_000,
        :n_epochs => (is_rrs ? 150 : 60), :eval_every => 1, :check_every => 2,
    ),
    "fmnist" => Dict(
        :alpha => 20.0, :k_max => 50, :h_dim => 500, :tr_sz => 60_000, :te_sz => 10_000,
        :n_epochs => (is_rrs ? 150 : 60), :eval_every => 1, :check_every => 2,
    ),
)

# The "synth" entry only serves as the list of option names that have a preset;
# the values come from the entry of the selected dataset.
dataset_presets = PRESETARGS[args[:dataset]]
for name in keys(PRESETARGS["synth"])
    if get(args, name, -1) == -1
        args[name] = dataset_presets[name]
    end
end

# In Jupyter, echo the flattened argument string that was parsed.
@jupyter println(args_str)

# Print one "name  =>  value" line per option, with the name right-aligned in 17 columns.
println("Parsed args:")
for (arg, val) in args
    padded_name = @sprintf("%17s", arg)
    println(padded_name, "  =>  ", val)
end

# Checkpoint/log location format: (jupyter,filename)/bnp-inf-arch-obs/args_str_flat

# `flatstr` (defined outside this file) turns the parsed arguments into three strings;
# only the first two are used in the code shown here.
args_str_flat, args_str_flat_method, args_str_flat_tr = flatstr(args)

# Name of this source file with ".jl" removed; fixed to "jupyter" when run in a notebook.
file_name = replace(basename(@__FILE__), ".jl" => "")
@jupyter file_name = "jupyter"
# @jupyter file_name = "bnpvae"

exp_name = join((file_name, args_str_flat_method, args_str_flat), "/")

# Create the log directory unless something already exists at that path.
log_dir = "../log/" * exp_name
ispath(log_dir) || mkpath(log_dir)

# Build a logger from two `SimpleLogger`s, one on stderr and one on `log.txt` in the
# log directory, and install it as the global logger.
# The file is opened with "w+", so an existing log is truncated.
# NOTE(review): `logger_io` is never closed in the part of the file shown here.
logger_io = open("$log_dir/log.txt", "w+")
stderr_logger = SimpleLogger(stderr)
file_logger = SimpleLogger(logger_io)
logger = CombinedLogger(stderr_logger, file_logger)
global_logger(logger)

# Record all parsed arguments in the log.
@info "Parsed args" args...

model_path = "$log_dir/model.jld2"

# A GPU id of -1 means "not specified"; only an explicit id is handed to Knet.
# NOTE(review): the original comment says Knet then uses any free GPU — confirm, since
# `gpu(false)` is called at the top of the file.
if args[:gpu_id] != -1
    Knet.gpu(args[:gpu_id])
end
@info "On GPU" Knet.gpu()

@info "Exp info" exp_name log_dir

;

In [None]:
# Load the training set, the test set and `features` (which may be `nothing`).
# See `@doc get_data` for what `is_save=true` does.
x_tr, x_te, features = get_data(args; is_save=true)
@info "Loaded data" size(x_tr)
x_dim = size(x_tr, 1)

# Shuffle the columns of both sets with a fixed seed.
# The training permutation is drawn first, then the test permutation.
Random.seed!(1234)
tr_perm = Random.randperm(args[:tr_sz])
te_perm = Random.randperm(args[:te_sz])

x_tr = x_tr[:, tr_perm]
x_te = x_te[:, te_perm]

# Notebook-only preview: plot some training samples and, when available, the features.
@jupyter begin
    plt.figure(figsize=(16, 9))
    ax = plot_grayimg(x_tr, 3, 20)
    ax."set_title"("Data samples")

    # Compare against `nothing` by identity (`!==`); `!=` goes through generic `==`,
    # which is not meant for testing the `nothing` singleton.
    if features !== nothing
        plt.figure(figsize=(16, 9))
        ax = plot_grayimg(features)
        ax."set_title"("Features")
    end
end

# One batch loader per split, both with the same batch size and array type `AT`.
batch_size = args[:batch_size]
tr_loader, te_loader = map(x -> BatchDataLoader(batch_size, x; atype=AT), (x_tr, x_te))

;

## Model and optimiser setup

In [None]:
# Either resume from the checkpoint at `model_path` (only when `--continue` is given and
# the file exists) or build a fresh model together with its optimisers.
# `epoch_end` is the index of the last finished epoch (0 for a fresh model).
if args[:continue] && ispath(model_path)
    model_dict = Knet.load(model_path)
    vae = model_dict["vae"]
    epoch_end = model_dict["epoch"]
    @info "Loaded `$model_path`"
else
    # A checkpoint exists but `--continue` was not given: warn before starting afresh.
    if ispath(model_path)
        @warn "$model_path exists!!"
    end
    vae = if args[:bnp] == "ibp"
        IBPVAE(x_dim, args)
    elseif args[:bnp] == "crp"
        CRPVAE(x_dim, args)
    else
        # The message previously interpolated an undefined variable `bnp`, which would
        # have raised `UndefVarError` instead of reporting the offending value.
        throw(ArgumentError("[IBPVAE] Unknown BNP model: $(args[:bnp])"))
    end
    if args[:inf] == "rrs"
        # "rrs" inference: `DynamicAdam` for the model, and a separate SGD optimiser with
        # 1.5x the learning rate for `vae.rho`.
        initoptim!(vae, (; kwargs...) -> DynamicAdam(true; kwargs...); lr=args[:lr], beta1=args[:beta1], beta2=args[:beta2])
        vae.rho.opt = SGD(; lr=1.5args[:lr])
    else
        initoptim!(vae, Adam; lr=args[:lr], beta1=args[:beta1], beta2=args[:beta2])
    end
    epoch_end = 0
end

@info vae

# The tensorboardX writer logs into the experiment directory. In Jupyter each run gets
# its own time-stamped sub-directory:
#   (jupyter,filename)/inference/args_str_flat/current_time
writer_dir = log_dir
@jupyter writer_dir = string(writer_dir, "/", Dates.format(now(), "dd-u-yyyy-H-M-S"))

writer = tbX.SummaryWriter(writer_dir)

;

In [None]:
# Continue the epoch numbering after the last finished epoch and train for
# `n_epochs` further epochs.
epoch_start = epoch_end + 1
epoch_end += args[:n_epochs]

# Accumulated seconds spent in training and in evaluation, respectively.
t_tr, t_te = 0, 0

t_total = @elapsed let progress=Progress(args[:n_epochs], desc="Training", barlen=31)
    for epoch = epoch_start:epoch_end
        global t_tr, t_te

        # One pass over the training set; the returned average loss is logged as the
        # negative ELBO.
        t_tr += @elapsed avg_loss = train!(vae, tr_loader; writer=writer, epoch=epoch)
        if isnan(avg_loss)
            @warn "Loss is NaN!"
        end
        writer."add_scalar"("dataset/neg_elbo", avg_loss, epoch)

        # Every `eval_every` epochs: evaluate on the test set and write visualisations
        # for the first test batch. `avg_eval` stays `missing` in all other epochs.
        avg_eval = missing
        if iszero(epoch % args[:eval_every])
            t_te += @elapsed avg_eval = evaluate(vae, te_loader)
            writer."add_scalar"("dataset/iwae_te", avg_eval, epoch)
            write_vae(writer, vae, first(te_loader), epoch)
        end

        # Every `check_every` epochs: save a checkpoint with the current epoch number.
        if iszero(epoch % args[:check_every])
            Knet.save(model_path, "vae", vae, "epoch", epoch)
            @info "Saved to `$model_path`" epoch
        end

        @script ProgressMeter.next!(progress; showvalues = [(:epoch, epoch), (:loss, avg_loss), (:eval, avg_eval)])
    end
end
@info "Time" t_total t_tr t_te

@info "Finished"

# Evaluate once on the training set and log the result at the last epoch index.
final_tr_eval = evaluate(vae, tr_loader)
writer."add_scalar"("dataset/iwae_tr", final_tr_eval, epoch_end)

# With `--store`, save the final model together with the last epoch index.
args[:store] && Knet.save(model_path, "vae", vae, "epoch", epoch_end)

;

In [None]:
# For annotation purposes only: `k_mode_rr` is passed to `make_Z_plot` and, when it is not
# `nothing`, is drawn as a red dashed vertical line on the plots further down.
# Leave it as `nothing` to get plots without that annotation.
k_mode_rr = nothing
;

In [None]:
# Collect Z for the whole test set: for every batch, take the third output of the model,
# and store the transpose of its mean as a plain CPU matrix.
Z_list = Matrix[]
for x_batch in te_loader
    dist_Z = vae(x_batch, Val(false))[3]
    push!(Z_list, Array(mean(dist_Z)'))
end

# Make sure all batches have the same number of columns by padding with zeros.
# This is needed to be compatible with RR methods.
n_cols_list = map(Z_batch -> size(Z_batch, 2), Z_list)
n_cols_max = maximum(n_cols_list)
for i in eachindex(Z_list)
    Z_list[i] = hcat(Z_list[i], zeros(size(Z_list[i], 1), n_cols_max - size(Z_list[i], 2)))
end

# Stack the batches row-wise, then pad up to `k_max` columns when still narrower.
Z = reduce(vcat, Z_list)
if size(Z, 2) < args[:k_max]
    Z = hcat(Z, zeros(size(Z, 1), args[:k_max] - size(Z, 2)))
end

# Number of columns in which at least one entry exceeds 0.01.
K_max = count(any(Z .> 0.01; dims=1))
@info "Non-negligible posterior activation probability" K_max

# Send the first three figures returned by `make_Z_plot` to TensorBoard and save each
# one as a PDF in the log directory.
ps = make_Z_plot(Z; k_mode_rr=k_mode_rr, subplotting=false)
fig_names = ["act", "freq", "hist"]
for i in eachindex(fig_names)
    fig_name = fig_names[i]
    fig = ps[i]
    writer."add_figure"("plots/$fig_name", fig)
    fig."savefig"("$log_dir/$fig_name.pdf", bbox_inches="tight")
end
@info "Plots in PDF format saved to $log_dir/"

# In Jupyter, additionally draw the same plots with `subplotting=true`.
@jupyter make_Z_plot(Z; k_mode_rr=k_mode_rr, subplotting=true)

# Report the evaluation result on both the training and the test set.
@info "IWAE on training and testing set" evaluate(vae, tr_loader) evaluate(vae, te_loader)

# Visualise the features.
# Deep model: bar plot of the mean absolute value of the posterior mean of A, per column.
# Otherwise: plot the first `k` columns of the decoder matrix `A` as images.
if args[:is_deep]
    # Fourth output of the model for every test batch; the transposed means are stacked
    # row-wise.
    mu_list = Matrix[]
    for x_batch in te_loader
        dist_A = vae(x_batch, Val(false))[4]
        push!(mu_list, Array(mean(dist_A)'))
    end
    mu = reduce(vcat, mu_list)
    # `abs` instead of `sqrt(x^2)`: same value, but no underflow/overflow in the square.
    A_mean_avg = vec(mean(abs.(mu); dims=1))

    p = plt.figure(figsize=(3,2))
    plt.bar(1:length(A_mean_avg), A_mean_avg)
    k_mode_rr !== nothing && plt.axvline(x=k_mode_rr, color="r", linestyle="--")
    plt.xlabel("k-th feature")
    plt.ylabel("L1 norm of " * L"$q_{A_k}$")
    # NOTE(review): the file is called "l2norm" although the plotted quantity and the axis
    # label are L1-based; the name is kept so existing consumers of the file keep working.
    p."savefig"("$log_dir/l2norm.pdf", bbox_inches="tight")
else
    p = plt.figure()
    # With "rrs" inference use the mode of the truncation-level distribution, otherwise
    # the fixed `k_max`.
    k = args[:inf] == "rrs" ? mode(value(vae.rho).lnpd) : args[:k_max]
    ax = plot_grayimg(Array(vae.decoder.A[:,1:k]))
    p."savefig"("$log_dir/features.pdf", bbox_inches="tight")
end

# Visualise the posterior over the truncation level (only for "rrs" inference).
if args[:inf] == "rrs"
    # Log the mode, plus mean and standard deviation of 1_000 samples.
    k_list = rand(value(vae.rho).lnpd, 1_000)
    @info "" mode(value(vae.rho).lnpd) mean(k_list) std(k_list)

    rho = value(vae.rho)
    # Evaluate the pdf for k = 1, ..., upper bound; the bound is `--k_max` when that is
    # larger than 1 and the model's own `k_max` otherwise.
    k_upper = args[:k_max] > 1 ? args[:k_max] : vae.k_max
    ks = 1:k_upper
    ps = map(k -> pdf(rho, k), ks)

    p = plt.figure(figsize=(3,3))
    # NOTE(review): the x values are shifted by one (`ks .- 1`) — confirm this offset is intended.
    plt.plot(ks .- 1, ps)
    plt.xlabel("k")
    plt.ylabel(L"P(K^*=k)")
    p."savefig"("$log_dir/dist_k.pdf", bbox_inches="tight")
end
;

In [None]:
# Do not continue past this point in script mode: the remaining cells only produce
# additional plots.
# NOTE(review): `@script` is defined outside this file; presumably it runs its argument
# only when not in Jupyter — confirm.
@script exit()
;

## Other plots in the paper

In [None]:
# Evaluate the model on both splits for a range of `k_max` values passed to `evaluate`.
k_list = [1, 5, 10, 11, 12, 13, 14, 15, 20, 30, 50]

iwae_list_tr = map(k -> evaluate(vae, tr_loader; k_max=k), k_list)
iwae_list_te = map(k -> evaluate(vae, te_loader; k_max=k), k_list)
@info "" iwae_list_tr iwae_list_te


# One curve per split; optionally mark `k_mode_rr` with a red dashed vertical line.
p = plt.figure(figsize=(3,2))
plt.plot(k_list, iwae_list_tr, "--"; label="Tr. set", alpha=0.5)
plt.plot(k_list, iwae_list_te, "-+"; label="Te. set", alpha=0.5)
if k_mode_rr !== nothing
    plt.axvline(x=k_mode_rr, color="r", linestyle="--")
end
plt.xlabel("k_masked")
plt.ylabel("IWAE")
plt.legend()
p."savefig"("$log_dir/masked_iwae.pdf", bbox_inches="tight")
;

In [None]:
# Time `evaluate` on the test set for several `k_max` values, repeated three times.
k_list = [10, 20, 50]

time_list = map(1:3) do _
    [@elapsed evaluate(vae, te_loader; k_max=k_max) for k_max in k_list]
end
@info "" time_list

# `m` and `s` are meant to be the per-`k` mean and standard deviation over the repeats.
# NOTE(review): confirm that `std` on a vector of vectors works elementwise on the
# Julia version in use.
m = mean(time_list)
s = std(time_list)

# Mean curve with a grey dashed band one standard deviation below and above it.
p = plt.figure(figsize=(3,2))
plt.plot(k_list, m; label="mean")
lower_band = m - s
upper_band = m + s
plt.plot(k_list, lower_band, "--"; color="grey", label="std")
plt.plot(k_list, upper_band, "--"; color="grey")
if k_mode_rr !== nothing
    plt.axvline(x=k_mode_rr, color="r", linestyle="--")
end
plt.xlabel("k_masked")
plt.ylabel("time (s)")
plt.xticks(k_list)
plt.legend()
p."savefig"("$log_dir/time.pdf", bbox_inches="tight")
;