In [1]:
using Plots
using ColoringNames
using CSVFiles
using DataFrames
# Select the PyPlot backend once, after every package has been loaded; this leaves
# the same final backend as the original pair of `pyplot()` calls.
pyplot()

runnum = "0a"

"""
    mklogdir(name; runid=runnum, root=pwd())

Create (if missing) and return the log directory `<root>/logs/point/<name><runid>`.

The result is what the notebook passes as `log_dir` to `train!`. Both keyword
defaults are evaluated at call time, so changing the global `runnum` (or the
working directory) affects later calls exactly as before.

# Arguments
- `name`: model label; the run id is appended to it directly (no separator).

# Keywords
- `runid`: run identifier suffix (defaults to the global `runnum`).
- `root`: base directory under which `logs/point/` is created (defaults to `pwd()`).
"""
function mklogdir(name; runid=runnum, root=pwd())
    logdir = joinpath(root, "logs", "point", name * runid)
    mkpath(logdir)  # no-op when the directory already exists
    return logdir
end

mklogdir (generic function with 1 method)

In [6]:
# Colour-name set; passed as `keep_words` below (presumably restricts the loaded
# embeddings to words occurring in these names — confirm).
const many_names = load_color_nameset()
# Word embeddings, vocabulary and encoding; `enc` is reused to encode the Munroe data.
const word_embs, vocab, enc = load_text_embeddings(keep_words=many_names)
# Munroe colour data encoded with `enc`. NOTE(review): `dev_as_train=false` and
# `dev_as_test=false` presumably keep the dev split out of train/test — confirm.
const full_data = load_munroe_data(dev_as_train=false, dev_as_test=false, encoding_ = enc)

# Evaluation datasets derived from `full_data`; the trailing `;` suppresses display of
# the value in the notebook.
const ord_data = order_relevant_dataset(full_data);
const extrapo_data = extrapolation_dataset(full_data);



In [49]:
"""
    get_saveplot(mdlname)

Return a `demofun` callback `(mdl, colornames, subfigname) -> …` for
`qualitative_demo`. The callback renders `plot_query(mdl, colornames)` and writes it
to `demo/point/<mdlname>/<subfigname>.png`. The target directory is created
immediately, when `get_saveplot` is called.
"""
function get_saveplot(mdlname)
    outdir = joinpath("demo", "point", mdlname)
    mkpath(outdir)
    return (mdl, colornames, subfigname) ->
        savefig(plot_query(mdl, colornames), joinpath(outdir, subfigname * ".png"))
end


"""
    _demo_maincolors()

Return a fresh vector of the colour names shown by `qualitative_demo`: hyphenated and
"-ish" names next to their reordered / base counterparts, plus a few basic colours.
"""
_demo_maincolors() = [
    "brown-orange", "orange-brown",
    "yellow-orange", "orange-yellow",
    "brownish green", "greenish brown",
    "bluish grey", "greyish blue",
    "pink-purple", "purple-pink",
    "green", "greenish",
    "purple", "purplish",
    "brown", "brownish",
    "black", "white", "grey",
]

"""
    _demo_oovcolors()

Return a fresh vector of the capitalisation/spelling variants queried as "oovcolors"
(presumably out-of-vocabulary for the models — confirm).
"""
_demo_oovcolors() = ["Brown", "Green", "Purple", "gray", "Gray"]

"""
    qualitative_demo(mdl, do_oov=true; demofun, maincolors, oov_names)

Run `demofun` on `mdl` for a fixed list of colour names and, if `do_oov`, for a
second list of name variants.

# Keywords
- `demofun`: called as `demofun(mdl, colornames, subfigname)`, with `subfigname` set
  to `"maincolors"` or `"oovcolors"`. By default it displays `plot_query(mdl, colornames)`.
- `maincolors`: names for the first call (defaults to `_demo_maincolors()`).
- `oov_names`: names for the second call (defaults to `_demo_oovcolors()`).

Returns `nothing`.
"""
function qualitative_demo(mdl, do_oov=true;
        demofun=(mdl, colornames, subfigname) -> display(plot_query(mdl, colornames)),
        maincolors=_demo_maincolors(),
        oov_names=_demo_oovcolors())
    demofun(mdl, maincolors, "maincolors")

    if do_oov
        demofun(mdl, oov_names, "oovcolors")
    end

    return nothing
end

# Empirical ("Direct") models always skip the second, "oovcolors" demo; keyword
# arguments such as `demofun` are forwarded unchanged.
# NOTE(review): presumably because an empirical model only knows the exact names it
# was trained on — confirm.
function qualitative_demo(mdl::TermToColorPointEmpirical; kwargs...)
    return qualitative_demo(mdl, false; kwargs...)
end

qualitative_demo (generic function with 3 methods)

In [50]:
"""
    create_res_df()

Return an empty results table with a `method` (`String`) column and an `mse`
(`Float64`) column.
"""
function create_res_df()
    return DataFrame(method=String[], mse=Float64[])
end

# One independent results table per evaluation setting; `perform_evaluation` appends to them.
full_df, extrapo_df, ord_df = create_res_df(), create_res_df(), create_res_df()

"""
    perform_evaluation(modeltype, name)

Train and evaluate one model type and record its scores.

`modeltype` is called as `modeltype(enc, word_embs; n_steps=...)` to build each model.
`name` labels the model in the result tables, log directories and demo plots.

Steps:
1. If `doextrapo(modeltype)`: train on `extrapo_data` (logging to
   `mklogdir(name * "_extrapo")`) and push its `evaluate` score on
   `extrapo_data.test` to `extrapo_df` as `"<name> Extrapolating"`.
2. Train on `full_data` (logging to `mklogdir(name)`), push its score on
   `full_data.test` to `full_df`, and save its demo plots via `get_saveplot(name)`.
3. Score the same full-data model on `ord_data.test` (pushed to `ord_df`) and on
   `extrapo_data.test` (pushed to `extrapo_df` as `"<name> Non-extrapolating"`).
4. Write all three tables to `results/point_{full,extrapo,ord}.csv`, creating
   `results/` if it is missing.

Returns the model trained on `full_data`.
"""
function perform_evaluation(modeltype, name)
    info(name)

    # Build a fresh model and train it on `cldata`, logging to `mklogdir(logname)`.
    # `n_steps` is the first dimension of the padded training terms (presumably the
    # maximum term length — confirm).
    function mdlfun(cldata; logname=name)
        mdl = modeltype(enc, word_embs; n_steps=size(cldata.train.terms_padded, 1))
        train!(mdl, cldata; log_dir=mklogdir(logname))
        return mdl
    end

    ###################
    # Extrapolation run: train on the extrapolation split, score on its test set.
    if doextrapo(modeltype)
        # Own log directory so these logs don't mix with the full-data run's logs.
        extrapo_mdl = mdlfun(extrapo_data; logname=name * "_extrapo")
        println()
        @show res_extrapo = evaluate(extrapo_mdl, extrapo_data.test)
        push!(extrapo_df, Dict(:method=>name * " Extrapolating", :mse=>res_extrapo))
    end

    ######
    # Full-data run: train on everything, score on the full test set, save demo plots.
    full_mdl = mdlfun(full_data)
    println()
    @show res_full = evaluate(full_mdl, full_data.test)
    push!(full_df, Dict(:method=>name, :mse=>res_full))

    qualitative_demo(full_mdl; demofun=get_saveplot(name))

    #######
    # The full-data model scored on the order-relevant test set.
    println()
    @show res_ord = evaluate(full_mdl, ord_data.test)
    push!(ord_df, Dict(:method=>name, :mse=>res_ord))

    #######
    # The full-data model scored on the extrapolation test set, for comparison with
    # the "Extrapolating" row.
    @show res_nonextrapo = evaluate(full_mdl, extrapo_data.test)
    println()
    push!(extrapo_df, Dict(:method=>name * " Non-extrapolating", :mse=>res_nonextrapo))

    ####
    # The tables are rewritten after every model, so results finished so far stay on
    # disk even if a later model fails. Writing into a missing directory fails, so
    # create it first.
    mkpath("results")
    save("results/point_full.csv", full_df)
    save("results/point_extrapo.csv", extrapo_df)
    save("results/point_ord.csv", ord_df)

    return full_mdl
end


"""
    doextrapo(modeltype)

Return whether `perform_evaluation` should also train and score `modeltype` on the
extrapolation split. This is `true` for every model type except
`TermToColorPointEmpirical`.
"""
function doextrapo end

doextrapo(modeltype) = true
# NOTE(review): presumably the empirical model cannot predict names withheld from its
# training data, so extrapolation is skipped — confirm.
doextrapo(::Type{TermToColorPointEmpirical}) = false

doextrapo (generic function with 2 methods)

In [None]:
# Model constructors to evaluate, each paired with the label used in the result
# tables, log directories and demo-plot paths.
namedmodels = [
    (TermToColorPointEmpirical, "Direct"),
    (TermToColorPointSOWE, "SOWE"),
    (TermToColorPointCNN, "CNN"),
    (TermToColorPointRNN, "RNN"),
]

foreach(namedmodels) do entry
    perform_evaluation(entry...)
end;

[1m[36mINFO: [39m[22m[36mDirect
[39m


res_full = evaluate(full_mdl, full_data.test) = 0.06635612f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.064801306f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.062132385f0



[1m[36mINFO: [39m[22m[36mSOWE
[39m2018-06-19 20:00:02.174178: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2



res_extrapo = evaluate(extrapo_mdl, extrapo_data.test) = 0.07881486f0





res_full = evaluate(full_mdl, full_data.test) = 0.06674081f0

res_ord = evaluate(full_mdl, ord_data.test) = 0.06605732f0
res_nonextrapo = evaluate(full_mdl, extrapo_data.test) = 0.06493569f0



[1m[36mINFO: [39m[22m[36mCNN
[39m

In [None]:
# Save the empirical ("Direct") model's demo plots under demo/point/Direct/.
# NOTE(review): `noml` must already be trained, e.g. via
#     noml = ColoringNames.TermToColorPointEmpirical(); train!(noml, full_data)

qualitative_demo(noml; demofun=get_saveplot("Direct"))


In [None]:
# Fit the empirical ("Direct") model on the full dataset, display its demo plots,
# and report its score on the dev split (same dataset as the SOWE cell).
noml = ColoringNames.TermToColorPointEmpirical()
train!(noml, full_data)

qualitative_demo(noml)
@show evaluate(noml, full_data.dev)

In [None]:
# Train the SOWE point model (presumably sum of word embeddings) on the full dataset,
# logging under "sowe", then display its demo plots and its score on the dev split.
sowe = TermToColorPointSOWE(full_data.encoding, word_embs; n_steps=4)
train!(sowe, full_data; log_dir=mklogdir("sowe"))
qualitative_demo(sowe)
@show evaluate(sowe, full_data.dev)

In [None]:
# Train the CNN point model on the full dataset (same data as the SOWE cell),
# logging under "cnn", then display its demo plots and its score on the dev split.
cnn = TermToColorPointCNN(full_data.encoding, word_embs; n_steps=4)

train!(cnn, full_data;
    log_dir=mklogdir("cnn"),
)
qualitative_demo(cnn)
@show evaluate(cnn, full_data.dev)

In [None]:
# Re-display the trained CNN's demo plots and its score on the dev split.
qualitative_demo(cnn)
@show evaluate(cnn, full_data.dev)

In [None]:
# Train the RNN point model on the full dataset (same data as the SOWE cell),
# logging under "rnn", then display its demo plots and its score on the dev split.
rnn = TermToColorPointRNN(full_data.encoding, word_embs; n_steps=4)

train!(rnn, full_data;
    log_dir=mklogdir("rnn"),
)
qualitative_demo(rnn)
@show evaluate(rnn, full_data.dev)

In [None]:
# Scratch cell: evaluates the literal `1`; it has no side effects.
1

In [None]:
using ColoringNames: order_relevant_dataset, order_relevant_name_pairs

In [None]:
# Name pairs from the dev split of the full dataset, as returned by `order_relevant_name_pairs`.
namepairs = order_relevant_name_pairs(full_data.dev);

In [None]:
# For each name pair, query the empirical model `noml` with both names and store the
# predictions row by row. Then list the indices of the 5 pairs with the smallest
# `hsv_squared_error` between their two predictions.
# NOTE(review): assumes `query` yields one 3-value row per name and that
# `hsv_squared_error` returns one error per row — confirm.
hsv1s = Array{Float64}((length(namepairs), 3))
hsv2s = similar(hsv1s)
for (row, (name_a, name_b)) in enumerate(namepairs)
    hsv1s[row, :] = query(noml, [name_a])
    hsv2s[row, :] = query(noml, [name_b])
end
selectperm(ColoringNames.hsv_squared_error(hsv1s, hsv2s), 1:5)

In [None]:
# The 5 name pairs whose two predictions are closest. The indices are recomputed here
# rather than hard-coded, so they stay correct when `namepairs` or the model changes.
namepairs[selectperm(ColoringNames.hsv_squared_error(hsv1s, hsv2s), 1:5)]

In [None]:
# Display predictions for each hyphenated name and its reversed order, to compare
# the effect of word order.
# NOTE(review): no visible cell binds a top-level `mdl`; set it to a trained model
# (e.g. `noml`, `sowe`, `cnn` or `rnn`) before running this cell.
for colorname in ["pink-purple", "purple-pink",
                  "brown-orange", "orange-brown",
                  "orange-yellow", "yellow-orange"]
    plot_query(mdl, colorname) |> display
end

In [None]:
# Check whether any element of `names` contains a hyphen.
# NOTE(review): no visible cell defines `names`. Unless it is bound elsewhere, it
# resolves to the function `Base.names`, and `collect` would fail. Possibly
# `many_names` was intended — confirm.
let has_hyphen = map(nm -> '-' in nm, collect(names))
    any(has_hyphen)
end