In [32]:
using ColoringNames
using DataFrames
using CSVFiles
# Run tag appended to every log-directory name created by `mklogdir`.
# NOTE(review): presumably bumped by hand between runs so new logs land in fresh directories.
runnum = "4"

"""
    mklogdir(name)

Ensure the directory `<pwd>/logs/point/<name><runnum>` exists (creating any missing
parent directories) and return its path.
"""
function mklogdir(name)
    dirpath = joinpath(pwd(), "logs", "point", name * runnum)
    isdir(dirpath) || mkpath(dirpath)
    return dirpath
end



mklogdir (generic function with 1 method)

In [33]:
# Colour-name set passed as `keep_words` when loading the embeddings below.
const many_names = load_color_nameset()

# Word embeddings (presumably restricted to `many_names`), their vocabulary, and the
# token encoding `enc` that every model in this notebook is built with.
const word_embs, vocab, enc = load_text_embeddings(keep_words = many_names)

# Main dataset, encoded with `enc`.
# NOTE(review): flags presumably mean "keep dev out of train, use dev as the test
# split" — confirm against ColoringNames.load_munroe_data.
const full_data = load_munroe_data(
    encoding_ = enc,
    dev_as_train = false,
    dev_as_test = true,
)

# Evaluation subsets derived from `full_data`.
const ord_data = order_relevant_dataset(full_data);
const extrapo_data = extrapolation_dataset(full_data);



In [34]:
"""
    create_res_df()

Return an empty results table with a `method` column of `String`s and an `mse`
column of `Float64`s.
"""
function create_res_df()
    return DataFrame(method = String[], mse = Float64[])
end

# One results table per evaluation setting; rows are appended by `perform_evaluation`.
full_df = create_res_df()
ord_df = create_res_df()
extrapo_df = create_res_df()

"""
    perform_evaluation(modeltype, name)

Train and score one point-prediction model type, append its scores to the global
result tables, and rewrite the result CSV files.

`modeltype` is called as `modeltype(enc, word_embs; n_steps=...)`, where `n_steps` is
the number of rows of the training split's `terms_padded`.

1. Unless `doextrapo(modeltype)` is false, a model is trained on `extrapo_data` and
   its `extrapo_data.test` score is appended to `extrapo_df` as "<name> Extrapolating".
2. A model is trained on `full_data`; its scores on `full_data.test`,
   `ord_data.test` and `extrapo_data.test` are appended to `full_df`, `ord_df` and
   `extrapo_df` (the last as "<name> Non-extrapolating").
3. All three tables are saved to `results/point_{full,extrapo,ord}.csv`.

Returns the model trained on `full_data`.
"""
function perform_evaluation(modeltype, name)
    info(name)

    # Build and train a fresh model on `cldata`. Each training run gets its own log
    # directory, so the extrapolating and full runs do not write into the same one.
    function mdlfun(cldata, logname)
        n_steps = size(cldata.train.terms_padded, 1)
        mdl = modeltype(enc, word_embs; n_steps=n_steps)
        train!(mdl, cldata; log_dir=mklogdir(logname))
        return mdl
    end

    # Extrapolation setting: model trained on the extrapolation split.
    if doextrapo(modeltype)
        extrapo_mdl = mdlfun(extrapo_data, "extrapo_" * name)
        println()
        @show res_extrapo = evaluate(extrapo_mdl, extrapo_data.test)
        push!(extrapo_df, Dict(:method=>name * " Extrapolating", :mse=>res_extrapo))
    end

    # Full setting: one model trained on the full training data, scored three ways.
    full_mdl = mdlfun(full_data, name)
    println()
    @show res_full = evaluate(full_mdl, full_data.test)
    push!(full_df, Dict(:method=>name, :mse=>res_full))

    println()
    @show res_ord = evaluate(full_mdl, ord_data.test)
    push!(ord_df, Dict(:method=>name, :mse=>res_ord))

    # Same test set as the extrapolating model, for a like-for-like comparison.
    @show res_nonextrapo = evaluate(full_mdl, extrapo_data.test)
    println()
    push!(extrapo_df, Dict(:method=>name * " Non-extrapolating", :mse=>res_nonextrapo))

    # Tables are rewritten after every model, so results of models finished so far
    # are on disk even if a later run is interrupted. `save` does not create the
    # output directory, so make sure it exists.
    mkpath("results")
    save("results/point_full.csv", full_df)
    save("results/point_extrapo.csv", extrapo_df)
    save("results/point_ord.csv", ord_df)

    return full_mdl
end


"""
    doextrapo(modeltype)

Return whether `perform_evaluation` should also train and score a model on
`extrapo_data`: `true` for every model type except `TermToColorPointEmpirical`.
"""
function doextrapo end
doextrapo(modeltype) = true
# NOTE(review): presumably skipped because the empirical model cannot predict terms
# missing from its training data — confirm.
doextrapo(::Type{TermToColorPointEmpirical}) = false

doextrapo (generic function with 2 methods)

In [35]:
"""
    qualitative_demo(mdl, do_oov=false)

Display `plot_query(mdl, queries)` for a fixed list of colour names. When `do_oov` is
true, the query "ish" (presumably an out-of-vocabulary probe) is appended to the list.
"""
function qualitative_demo(mdl, do_oov=false)
    queries = String[
        "black", "brownish green", "brown", "brownish", "greenish",
        "greenish brown", "green", "red", "orange", "blue",
    ]
    if do_oov
        push!(queries, "ish")
    end
    return display(plot_query(mdl, queries))
end

# One-argument call for the empirical model; same as the generic default
# (`do_oov=false`), presumably kept as an explicit guard against the OOV query.
qualitative_demo(mdl::TermToColorPointEmpirical) = qualitative_demo(mdl, false)

qualitative_demo (generic function with 3 methods)

In [36]:
####################################################

In [37]:
# (model type, display name) pairs, evaluated in this order. The name labels the
# result-table rows and the log directories.
namedmodels = [
    (TermToColorPointEmpirical, "Direct"),
    (TermToColorPointSOWE, "SOWE"),
    (TermToColorPointCNN, "CNN"),
    (TermToColorPointRNN, "RNN"),
]

foreach(namedmodels) do entry
    perform_evaluation(entry...)
end;

[1m[36mINFO: [39m[22m[36mDirect
[39m


res_full = evaluate(full_mdl, full_data.test) = 0.06635567f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.057472266f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.06369885f0



[1m[36mINFO: [39m[22m[36mSOWE
[39m


res_extrapo = evaluate(extrapo_mdl, extrapo_data.test) = 0.09304764f0

res_full = evaluate(full_mdl, full_data.test) = 0.066772565f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.06410597f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.076956816f0



[1m[36mINFO: [39m[22m[36mCNN
[39m


res_extrapo = evaluate(extrapo_mdl, extrapo_data.test) = 0.0893576f0

res_full = evaluate(full_mdl, full_data.test) = 0.06703785f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.05773749f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.074265406f0



[1m[36mINFO: [39m[22m[36mRNN
[39m


res_extrapo = evaluate(extrapo_mdl, extrapo_data.test) = 0.14179087f0

res_full = evaluate(full_mdl, full_data.test) = 0.07118864f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.06874289f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.13543682f0



In [None]:
# Empirical ("Direct") baseline trained interactively on `full_data`, the dataset
# loaded at the top of the notebook. It is constructed the same way
# `perform_evaluation` constructs it, with `n_steps` derived from the padded training terms.
noml = ColoringNames.TermToColorPointEmpirical(enc, word_embs;
    n_steps=size(full_data.train.terms_padded, 1))
train!(noml, full_data)

qualitative_demo(noml)
@show evaluate(noml, full_data.dev)

In [27]:
# SOWE model trained interactively on `full_data`. `n_steps` is derived from the
# padded training terms (as in `perform_evaluation`) rather than hard-coded, so it
# tracks the data.
sowe = TermToColorPointSOWE(full_data.encoding, word_embs;
    n_steps=size(full_data.train.terms_padded, 1))
train!(sowe, full_data;
    log_dir=mklogdir("sowe")
)
qualitative_demo(sowe)
@show evaluate(sowe, full_data.dev)



LoadError: InterruptException:

In [None]:
# CNN model trained interactively on `full_data`, the dataset loaded at the top of
# the notebook. `n_steps` is derived from the padded training terms (as in
# `perform_evaluation`) rather than hard-coded.
cnn = TermToColorPointCNN(full_data.encoding, word_embs;
    n_steps=size(full_data.train.terms_padded, 1))

train!(cnn, full_data;
    log_dir=mklogdir("cnn"),
)
qualitative_demo(cnn)
@show evaluate(cnn, full_data.dev)

In [None]:
# Re-run the demo and the dev-set score for an already-trained `cnn` on `full_data`.
qualitative_demo(cnn)
@show evaluate(cnn, full_data.dev)

In [None]:
# RNN model trained interactively on `full_data`, the dataset loaded at the top of
# the notebook. `n_steps` is derived from the padded training terms (as in
# `perform_evaluation`) rather than hard-coded.
rnn = TermToColorPointRNN(full_data.encoding, word_embs;
    n_steps=size(full_data.train.terms_padded, 1))

train!(rnn, full_data;
    log_dir=mklogdir("rnn"),
)
qualitative_demo(rnn)
@show evaluate(rnn, full_data.dev)

In [None]:
# Scratch cell: evaluates the literal 1 and has no other effect.
1

In [None]:
using ColoringNames: order_relevant_dataset, order_relevant_name_pairs

In [None]:
# Order-relevant name pairs from the dev split of `full_data`.
# NOTE(review): presumably pairs that differ only in word order
# (e.g. "pink-purple" / "purple-pink") — confirm against order_relevant_name_pairs.
namepairs = order_relevant_name_pairs(full_data.dev);

In [None]:
# Query the empirical model `noml` with both names of every order-relevant pair,
# storing one predicted HSV point per row.
hsv1s = Array{Float64}((length(namepairs), 3))
hsv2s = similar(hsv1s)
for (row, pair) in enumerate(namepairs)
    first_name, second_name = pair
    hsv1s[row, :] = query(noml, [first_name])
    hsv2s[row, :] = query(noml, [second_name])
end
# Indices of the 5 smallest entries of the HSV squared error between the two
# orderings (assumes hsv_squared_error returns one error per row — confirm).
selectperm(ColoringNames.hsv_squared_error(hsv1s, hsv2s), 1:5)

In [None]:
# Inspect hand-picked pairs. These indices are hard-coded (presumably copied from the
# selectperm result), so re-check them after retraining.
let picks = [12, 15, 13, 14, 16]
    namepairs[picks]
end

In [None]:
# Plot both word orders of a few hyphenated colour names, each query once.
# NOTE(review): `mdl` is not bound in the visible cells — assign the model to inspect
# (e.g. `noml`, `sowe`, `cnn`, `rnn`) to `mdl` before running this cell.
for query_name in ["pink-purple", "purple-pink",
                   "brown-orange", "orange-brown",
                   "orange-yellow", "yellow-orange"]
    plot_query(mdl, query_name) |> display
end

In [None]:
# Does any name contain a hyphen?
# `|>` binds tighter than `∈`, so `'-' .∈ collect(names) |> any` parsed as
# `'-' .∈ any(collect(names))`. `any(pred, itr)` expresses the intended check and
# short-circuits without building intermediate arrays.
# NOTE(review): at top level `names` is `Base.names` unless rebound in an earlier
# cell — presumably `many_names` was meant; confirm.
any(n -> '-' in n, names)