In [None]:
using WordNet
using WordEmbeddings, SoftmaxClassifier
using Utils
using Query
using Distances
using Iterators
using NearestNeighbors
using JLD
using Trees
using AbstractTrees

In [None]:
# Load the object stored under the key "ee" of the JLD file. Later cells read
# `ee.classification_tree`, `ee.embedding` and `ee.codebook` from it.
ee = load("../eval/models/plain/tokenised_lowercase_WestburyLab.wikicorp.201004_50__i1.jld","ee");
# Disabled alternative: restore a different model from a `.model` file.
#ee = restore("../eval/models/ss/keep/tokenised_lowercase_WestburyLab.wikicorp.201004_100_i1.model")
# NOTE(review): the disabled line below looks garbled -- a stray `using AbstractTrees`
# appears to have been pasted into the middle of an identifier. Confirm the intended call.
#dtree,labels = nn_using AbstractTreestree(ee)
#""

In [None]:
import BlossomV: Matching, add_edge, get_match,PerfectMatchingCtx,solve

# Convenience methods so that `Matching` can be called with any `Integer` type.

# Two-argument form: convert both counts to `Int32` and delegate.
function Matching(node_num::Integer, edge_num_max::Integer)
    return Matching(Int32(node_num), Int32(edge_num_max))
end

# One-argument form: size the edge capacity for a complete graph on `node_num`
# nodes, i.e. node_num*(node_num-1)/2 edges.
function Matching(node_num::Integer)
    max_edges = node_num * (node_num - 1) ÷ 2
    return Matching(node_num, max_edges)
end

# Accept any `Integer` node indexes and cost, converting each to `Int32` before
# delegating to the `Int32` method of `add_edge`.
function add_edge(matching::PerfectMatchingCtx, first_node::Integer,
                  second_node::Integer, cost::Integer)
    node_a = Int32(first_node)
    node_b = Int32(second_node)
    weight = Int32(cost)
    return add_edge(matching, node_a, node_b, weight)
end


# Validating wrapper around `add_edge`.
#
# Raises an error (via `error`) when:
#   * the edge would connect a node to itself,
#   * either node index is negative (indexes are zero-based),
#   * the cost is negative (a cost of zero is accepted).
# Otherwise the edge is added with `add_edge` and its result is returned.
function safe_add_edge(matching::PerfectMatchingCtx,
                        first_node::Integer,
                        second_node::Integer, cost::Integer)
    first_node != second_node || error("Can not have an edge between $(first_node) and itself")
    first_node >= 0  || error("first_node less than zero (value: $(first_node)). Indexes are zero-based")
    second_node >= 0  || error("second_node less than zero (value: $(second_node)). Indexes are zero-based")
    # Zero is allowed, so the requirement is "non-negative", not "positive".
    cost >= 0  || error("Cost must be non-negative. edge between $(first_node) and $(second_node) is $cost")

    return add_edge(matching, first_node, second_node, cost)
end

# Accept any `Integer` node index, converting it to `Int32` before delegating.
function get_match(matching::PerfectMatchingCtx, node::Integer)
    return get_match(matching, Int32(node))
end


"1 indexed" #TODO make this zero indexed
function get_all_matchs(m::PerfectMatchingCtx, n_nodes::Integer)
    #HACK: the number of nodes has to be supplied by the caller, because it is
    # not pulled from the Matching object.
    #
    # The returned Task produces every matched pair exactly once, as 1-based
    # node numbers; `get_match` itself is queried with zero-based indexes.
    Task() do
        seen = falses(n_nodes)
        for node in 1:n_nodes
            if !seen[node]
                partner = get_match(m, node - 1) + 1
                # A partner that was already emitted would mean an inconsistent matching.
                @assert(!seen[partner])
                seen[node] = true
                seen[partner] = true

                produce(node, partner)
            end
        end
    end
end



In [None]:
# Walk `tree` breadth-first and return `(nodes_at_depth, codes_at_depth)`.
#
# `nodes_at_depth[d]` holds the nodes found `d` steps below `tree` (the root itself
# is not included) and `codes_at_depth[d][k]` is the list of child indexes that leads
# from `tree` to `nodes_at_depth[d][k]`. Levels without nodes are not stored.
function levels(tree)
    nodes_at_depth = Vector{Vector{BranchNode}}()
    codes_at_depth = Vector{Vector{Vector{Int64}}}()

    #TODO: fix so works for partiailly empty root node
    push!(nodes_at_depth, tree.children)
    push!(codes_at_depth, [[idx] for idx in 1:length(tree.children)])

    while true
        next_nodes = BranchNode[]
        next_codes = Vector{Int64}[]

        # Expand every node of the deepest level stored so far.
        for (code, node) in zip(codes_at_depth[end], nodes_at_depth[end])
            for (idx, child) in enumerate(node.children)
                push!(next_nodes, child)
                push!(next_codes, vcat(code, idx)) # new vector: parent code plus child index
            end
        end

        # An empty level means the bottom was reached; it is not saved.
        isempty(next_nodes) && break

        push!(nodes_at_depth, next_nodes)
        push!(codes_at_depth, next_codes)
    end

    return nodes_at_depth, codes_at_depth
end


In [None]:
# Pick a small subtree by following a fixed path of child indexes from the root.
# Indexing a node with `[...]` relies on a `getindex` method for `BranchNode`
# (one is defined in a later cell of this file).
little_tree = ee.classification_tree[2][2][1][1][2][2][2][1][2][1][2][1][1][1][1]
# Split the subtree into per-depth node lists and their index codes.
nodes_at_depth, codes_at_depth = levels(little_tree);
# Display the codes.
codes_at_depth

In [None]:
"""dist_func must resturn a values between 0.0 and 1.0"""
function most_similar_pairings(dist_func::Function, items, consider_nearest_n::Integer)
    # Builds a matching problem over `items` and solves it.
    #   dist_func          - called as dist_func(a, b); results must not exceed 1.0
    #   items              - indexable collection of the things to pair up
    #   consider_nearest_n - per item, only edges to this many of its closest
    #                        later items are added (all of them when there are fewer)
    # Returns the collected output of `get_all_matchs` (1-based index pairs).
    n_items = length(items)
    m = Matching(n_items)
    # Edge costs are integers: distances are mapped onto 0:typemax(Int32)>>4.
    scale = typemax(Int32)>>4
    for ii in 1:n_items-1
        jjs = ii+1:n_items
        dists = [dist_func(items[ii], items[jj]) for jj in jjs]
        # Positions into `jjs`/`dists` of the candidates that get an edge. Keeping
        # positions (rather than the selected node indexes alone) ensures every
        # edge is weighted with the distance that belongs to that node.
        chosen = if consider_nearest_n < length(jjs)
            selectperm(dists, 1:consider_nearest_n)
        else
            1:length(jjs)
        end

        for kk in chosen
            jj = jjs[kk]
            dist = dists[kk]
            @assert dist<=1.0
            approx_dist = round(Int32, dist*scale)
            # The matching uses zero-based node indexes.
            safe_add_edge(m, ii-1, jj-1, approx_dist)
        end
    end
    solve(m)
    return collect(get_all_matchs(m, n_items))
end

In [None]:
"Returns a new tree, with the same structure but different values for the data"
function treemap(node::BranchNode; leaf_transform=identity, internal_transform=identity)
    # Recursively rebuild the subtree below `current`. Leaves get their data passed
    # through `leaf_transform`, all other nodes through `internal_transform`.
    function rebuild(current)
        transform = Trees.isleaf(current) ? leaf_transform : internal_transform
        new_data = transform(current.data)
        new_children = BranchNode[rebuild(child) for child in current.children]
        return BranchNode(copy(new_children), new_data)
    end
    return rebuild(node)
end
    

In [None]:
#TODO make this a testcase
# Small hand-built tree: root "1" with a leaf "11" and a branch "12" that has
# the two leaves "121" and "122".
x = BranchNode(
[BranchNode([],"11"),BranchNode([
        BranchNode([],"121"),BranchNode([],"122"),
        ],"12")],
    "1"
)
print_tree(STDOUT, x)

println()
# Map the tree: leaf data is prefixed with "L", data of all other nodes with "x".
y=treemap(x, leaf_transform = word->"L"*word, internal_transform = dummy -> "x"*dummy)
print_tree(STDOUT, y)

In [None]:
# Display a node as its `data` only (children are not shown).
Base.show(io::IO, node::BranchNode) = show(io, node.data)
                 

In [None]:
"""
Assumes that the `tree` is already a Huffman tree.
"""
function semhuff4(classification_tree, embeddings, consider_nearest_n)
    # Returns a new tree with the same node layout as `classification_tree`, in which
    # the nodes of every level (from the deepest one upwards) are re-attached in pairs
    # to the branch nodes of the level above. Pairs are chosen by
    # `most_similar_pairings` with `Query.angular_dist` on the node embeddings.
    #   classification_tree - tree whose leaves carry keys of `embeddings`; every
    #                         non-leaf below the root must have exactly 2 children
    #   embeddings          - indexable by leaf data, yielding that leaf's vector
    #   consider_nearest_n  - forwarded to `most_similar_pairings`
    # The input tree is not modified; all re-wiring happens on a copy made by `treemap`.
    sim_tree = treemap(classification_tree,
                            leaf_transform = word->word,
                            internal_transform = dummy -> "")

    @assert(!(sim_tree === classification_tree))
    nodes_at_depth, codes_at_depth = levels(sim_tree);
    maxdepth = length(nodes_at_depth)

    # Every node on the deepest level is looked up directly in `embeddings`.
    embeds = [embeddings[node.data] for node in nodes_at_depth[maxdepth]]
    for depth in maxdepth:-1:2
        nodes = nodes_at_depth[depth]

        # `embeds[k]` is the vector of `nodes[k]`, so the returned index pairs
        # address both collections.
        pair_indexes = most_similar_pairings(Query.angular_dist, embeds, consider_nearest_n)

        #We will now, assign the new nodes to parents in arbitary order
        nodes_above = nodes_at_depth[depth-1]
        embeds_above = typeof(embeds)(length(nodes_above))
        pair_jj = 1

        for (above_ii,node_above) in enumerate(nodes_above)

            if Trees.isleaf(node_above)
                embeds_above[above_ii] = embeddings[node_above.data]
            else
                #It is a branch so put a pair of nodes here
                @assert (length(node_above.children) == 2)
                child_index1, child_index2 = pair_indexes[pair_jj]
                pair_jj += 1
                node_above.children[1] = nodes[child_index1]
                node_above.children[2] = nodes[child_index2]
                # A branch is represented by the mean of its two children's vectors.
                embeds_above[above_ii] = (embeds[child_index1] + embeds[child_index2])/2.0
            end
        end
        @assert(pair_jj == length(pair_indexes) + 1, "$(pair_jj) != $(length(pair_indexes)) + 1") #All pairs must be assigned.

        embeds = embeds_above
    end
    return sim_tree
end
    
    

In [None]:
# Rebuild the model's full classification tree with `semhuff4`, considering the
# 30 nearest candidates per item, and time the call.
midi_tree = ee.classification_tree
@time semtree = semhuff4(midi_tree, ee.embedding, 30);
#print_tree(STDOUT, semtree)

In [59]:
# Look up the codebook entry for the word "ecig". The recorded run below shows a
# KeyError, i.e. that word is not a key of `ee.codebook`.
ee.codebook["ecig"]

LoadError: LoadError: KeyError: key "ecig" not found
while loading In[59], in expression starting on line 1

In [15]:
# Draw 6 random positions (duplicates are possible) and pick the embedding
# vectors and the words found at those positions.
ee_keys = rand(1:length(ee.embedding),6)
wvs = collect(values(ee.embedding))[ee_keys]
words = collect(keys(ee.embedding))[ee_keys]

6-element Array{String,1}:
 "linderman"    
 "undercarriage"
 "insisted"     
 "toolkit"      
 "twenty-seven" 
 "presupposes"  

In [None]:
using Query
# Replace `wvs` with six random 50-element vectors for a small matching experiment.
wvs = [randn(50) for i in 1:6]
# Matching over one node per vector, with room for length(wvs)^2 edges.
m=Matching(Int32(length(wvs)),Int32(length(wvs))^2)

In [None]:
# Table of the integer-scaled distances that are handed to the matching.
# Only entries with jj > ii are written; every other entry stays at Inf so that
# it can never be picked as a minimum. (Initialising with `fill` avoids scaling an
# uninitialised integer matrix by Inf, which yields NaN or -Inf entries.)
sims = fill(Inf, length(wvs), length(wvs))
for ii::Int32 in (1:length(wvs)-1)
    wv_ii=wvs[ii]
    for jj::Int32 in (ii+1:length(wvs))

        wv_jj=wvs[jj]
        dist = Query.angular_dist(wv_ii, wv_jj)

        # Edge costs are integers: scale the distance before rounding.
        scale = typemax(Int32)>>16
        approx_dist = round(Int32, dist*scale)
        sims[ii,jj]=approx_dist
        println(join([m, ii-1, jj-1, approx_dist],"\t"))
        # The matching uses zero-based node indexes.
        add_edge(m, ii-1,jj-1 , approx_dist)
    end
end
sims

In [None]:
# Solve the matching `m` whose edges were added in the previous cell.
BlossomV.solve(m)

In [None]:
# For each of the 6 nodes (zero-based `ii`), print its 1-based number, the partner
# assigned by the solved matching, and the column of the smallest entry in its row
# of `sims`. Only entries with a higher column than row were written into `sims`.
for ii in (1:6)-1
    match = get_match(m,ii)+1
    greed = indmin(sims[ii+1,:]-1)
    println(ii+1,"\t|\t", match, "\t", greed)
end

In [None]:

# List every matched pair of the 6-node matching once (1-based node numbers).
collect(get_all_matchs(m,6))

In [None]:
# Indexing a node indexes its children, so `node[i][j]` walks down the tree.
function Base.getindex(node::BranchNode, idx)
    return node.children[idx]
end


In [None]:
# Exact 2^x for an Int64 exponent, built directly from the Float64 bit pattern.
#   x > 1023           -> Inf  (beyond the largest finite exponent)
#   x < -1074          -> 0.0  (below the smallest subnormal)
#   -1022 <= x <= 1023 -> normal number: biased exponent (x + 1023) in bits 52..62
#   -1074 <= x < -1022 -> subnormal: a single mantissa bit at position x + 1074
@inline function Base.exp2(x::Int64)
    x > 1023 && return Inf
    x < -1074 && return 0.0
    bits = if x > -1023
        # (1 << 62) is a biased exponent of 1024, so adding (x - 1) gives x + 1023.
        (1 << 62) + ((x - 1) << 52)
    else
        1 << (x + 1074)
    end
    return reinterpret(Float64, bits)
end

# Number of steps on the path between two tree positions given by their codes:
# the elements of `code1` plus the elements of `code2` that lie beyond their
# common prefix. Identical codes give 0.
function get_path_dist(code1,code2)
    # Count the common prefix with an explicit counter. Reading a `for` loop
    # variable after the loop only works on old Julia versions; newer ones give
    # the loop its own variable, which would leave the count at zero.
    common = 0
    for (step1, step2) in zip(code1, code2)
        step1 == step2 || break
        common += 1
    end
    return (length(code1) - common) + (length(code2) - common)
end

# Similarity of two tree codes: 1.0 for identical codes, halved for every step
# on the path between them.
function get_path_sim(code1,code2)
    distance = get_path_dist(code1, code2)
    return exp2(-distance)
end

using Base.Test
# Expected values are exact powers of two, so `==` on floats is safe here.

# Codes of length 1: identical, then differing at the first element.
@test get_path_sim([1],[1]) == 1.0
@test get_path_sim([1],[0]) == 0.25

# Codes of equal length 2: the similarity depends on where they first differ.
@test get_path_sim([1,1],[1,0]) == 0.25
@test get_path_sim([1,1],[1,1]) == 1.0
@test get_path_sim([1,1],[0,1]) == 1/16
@test get_path_sim([1,1],[0,0]) == 1/16

# Codes of different lengths, including one being a prefix of the other.
@test get_path_sim([1],[1,1]) == 0.5
@test get_path_sim([1],[1,1,1]) == 0.25
@test get_path_sim([1,1],[1,1,1]) == 0.5
@test get_path_sim([1,0],[1,1,1]) == 0.125
@test get_path_sim([1,1,1],[1,1,1]) == 1.0


In [None]:
# Display the loaded model object.
ee 