In [11]:
using WordNet
using WordEmbeddings, SoftmaxClassifier
using Utils
using Query
using Distances
using Iterators
using NearestNeighbors
using JLD
using Trees
using AbstractTrees
using BlossomV

In [7]:
ee = load("../eval/models/plain/tokenised_lowercase_WestburyLab.wikicorp.201004_50__i1.jld","ee");


In [9]:
using Lumberjack
remove_truck("console")
add_truck(LumberjackTruck(STDOUT, "info", Dict{Symbol,Any}(:is_colorized => true)), "console")
add_truck(LumberjackTruck("semhuff_ipynb.log"), "file-logger")

Lumberjack.LumberjackTruck(IOStream(<file semhuff_ipynb.log>),nothing,Dict{Any,Any}(Pair{Any,Any}(:is_colorized,false),Pair{Any,Any}(:uppercase,false)))

In [10]:
#run(`tail semhuff_ipynb.log`)

In [22]:
import BlossomV.dense_num_edges
"""
dist_func must resturn a values between 0.0 and 1.0. 
result is a matrix where each column is a pairing of items, expressed as there index (1 indexed)
"""
function most_similar_pairings(dist_func::Function, items, consider_nearest_n::Integer)
    m = Matching(length(items), min(dense_num_edges(length(items)), length(items)*consider_nearest_n))
    #sims = Matrix{Int32}(length(items),length(items)).*Inf #for Debug purposes
    for ii in (1:length(items)-1)
        jjs = ii+1:length(items)
        dists = [dist_func(items[ii], items[jj]) for jj in jjs]
        nearest_jjs = if consider_nearest_n < length(jjs)
            jjs[selectperm(dists, 1:consider_nearest_n)]
        else
            jjs
        end
            
        for (jj,dist) in zip(nearest_jjs,dists)
            @assert dist<=1.0
            scale = typemax(Int32)>>4
            approx_dist = round(Int32, dist*scale)
            #sims[ii,jj]=approx_dist
            #println(join([m, ii-1, jj-1, approx_dist],"\t"))
            add_edge(m, ii-1,jj-1 , approx_dist)
        end
    end
    solve(m)
    get_all_matches(m, length(items)) .+ 1    
end



most_similar_pairings

In [26]:
"Returns a new tree, with the same structure but different values for the data"
function transform_tree(node::BranchNode; leaf_transform=identity, internal_transform=identity)
    function treemap_inner(node, leaf_transform, internal_transform)
        data = Trees.isleaf(node) ? leaf_transform(node.data) : internal_transform(node.data)
        children = BranchNode[treemap_inner(child, leaf_transform, internal_transform)  for child in node.children]
        new_node = BranchNode(copy(children), data)
    end
    treemap_inner(node, leaf_transform, internal_transform)
end
    

transform_tree

In [27]:
#TODO make this a testcase
x = BranchNode(
[BranchNode([],"11"),BranchNode([
        BranchNode([],"121"),BranchNode([],"122"),
        ],"12")],
    "1"
)
print_tree(STDOUT, x)

println()
y=transform_tree(x, leaf_transform = word->"L"*word, internal_transform = dummy -> "x"*dummy)
print_tree(STDOUT, y)

Trees.BranchNode with 2 children. data = "1"
├─ Trees.BranchNode with 0 children. data = "11"
└─ Trees.BranchNode with 2 children. data = "12"
   ├─ Trees.BranchNode with 0 children. data = "121"
   └─ Trees.BranchNode with 0 children. data = "122"

Trees.BranchNode with 2 children. data = "x1"
├─ Trees.BranchNode with 0 children. data = "L11"
└─ Trees.BranchNode with 2 children. data = "x12"
   ├─ Trees.BranchNode with 0 children. data = "L121"
   └─ Trees.BranchNode with 0 children. data = "L122"


In [28]:
"""
Assumes that the `tree` is already a Huffman tree.
"""
function semhuff(classification_tree, embeddings, consider_nearest_n)
    embedding_dim = length(first(embeddings))
    sim_tree = transform_tree(classification_tree, 
                            leaf_transform = word->word,
                            internal_transform = dummy -> "")
    
    nodes_at_depth, codes_at_depth = levels(sim_tree)
    maxdepth = length(nodes_at_depth)
    
    
    embeds = [embeddings[node.data] for node in nodes_at_depth[maxdepth]]
    #Dict(code => embeddings[node.data] for (code, node) in zip(codes_at_depth[maxdepth], nodes_at_depth[maxdepth]))
    for depth in maxdepth:-1:2
        info("semantically sorting level: $depth")
        nodes = nodes_at_depth[depth]
        
        pair_indexes = most_similar_pairings(Query.angular_dist, embeds, consider_nearest_n)
        
        #We will now, assign the new nodes to parents in arbitary order
        nodes_above = nodes_at_depth[depth-1]
        embeds_above = typeof(embeds)(length(nodes_above))
        pair_jj = 1

        for (above_ii,node_above) in enumerate(nodes_above)

            if Trees.isleaf(node_above)
                embeds_above[above_ii] = embeddings[node_above.data]
            else
                #It is a branch so put a pair of nodes here
                @assert (length(node_above.children) == 2)
                child_index1  = pair_indexes[1,pair_jj]
                child_index2  = pair_indexes[2,pair_jj]
                pair_jj += 1
                node_above.children[1] = nodes[child_index1]
                node_above.children[2] = nodes[child_index2]
                embeds_above[above_ii] = (embeds[child_index1] + embeds[child_index2])/2.0
                #@show nodes[child_index1], nodes[child_index2]
            end
        end
        @assert(pair_jj - 1  == size(pair_indexes,2), "$(pair_jj) != $(length(pair_indexes)) + 1") #All pairs must be assigned.
            
        embeds = embeds_above
    end
    sim_tree
end
    
    

semhuff

In [29]:
midi_tree = ee.classification_tree
@time semtree = semhuff(midi_tree, ee.embedding, 30);
#print_tree(STDOUT, semtree)

2016-08-04T20:34:43.56 - info: semantically sorting level: 22
perfect matching with 6518 nodes and 195075 edges
    starting init...done [0.463 secs]. 10 trees
    .8.6.4.2.0.
done [0.537 secs]. 928 grows, 0 expands, 17 shrinks
    expands: [0.000 secs], shrinks: [0.001 secs], dual updates: [0.000 secs]
2016-08-04T20:34:56.502 - info: semantically sorting level: 21
perfect matching with 32324 nodes and 969255 edges
    starting init...done [2.473 secs]. 68 trees
    .64.56.48.40.32.24.16.8.6.4.2.0.
done [2.825 secs]. 7170 grows, 4 expands, 99 shrinks
    expands: [0.002 secs], shrinks: [0.002 secs], dual updates: [0.002 secs]
2016-08-04T20:40:29.59 - info: semantically sorting level: 20
perfect matching with 34866 nodes and 1045515 edges
    starting init...done [2.809 secs]. 78 trees
    .64.56.48.40.32.24.16.8.6.4.2.0.
done [3.091 secs]. 6635 grows, 4 expands, 117 shrinks
    expands: [0.002 secs], shrinks: [0.004 secs], dual updates: [0.002 secs]
2016-08-04T20:46:39.037 - info: sema

Trees.BranchNode with 2 children. data = ""

In [42]:
print_tree(STDOUT, semtree[1][1][1][1][1][1][1][1][2])

Trees.BranchNode with 2 children. data = ""
├─ Trees.BranchNode with 2 children. data = ""
│  ├─ Trees.BranchNode with 2 children. data = ""
│  │  ├─ Trees.BranchNode with 2 children. data = ""
│  │  │  ├─ Trees.BranchNode with 2 children. data = ""
│  │  │  │  ├─ Trees.BranchNode with 0 children. data = "patrick"
│  │  │  │  └─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     ├─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     │  ├─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     │  │  ├─ Trees.BranchNode with 0 children. data = "robbins"
│  │  │  │     │  │  └─ Trees.BranchNode with 0 children. data = "baxter"
│  │  │  │     │  └─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     │     ├─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     │     │  ├─ Trees.BranchNode with 0 children. data = "goddard"
│  │  │  │     │     │  └─ Trees.BranchNode with 2 children. data = ""
│  │  │  │     │     │     ├─ Trees.BranchNode with 0 children. dat

In [None]:
open("tree.txt","w") do fp
    print_tree(fp, semtree)
end

In [45]:
import WordEmbeddings: NetworkType

type SemHuff <: NetworkType
    source::GenWordEmbedding
end

LoadError: LoadError: invalid redefinition of constant SemHuff
while loading In[45], in expression starting on line 3

In [52]:
using Training
using PooledElements



search: initialize_embedding



No documentation found.

`Training.initialize_embedding` is a `Function`.

```
# 4 methods for generic function "initialize_embedding":
initialize_embedding(embed::WordEmbeddings.FixedWordSenseEmbedding, ::WordEmbeddings.RandomInited) at /mnt_volume/phd/prototypes/SenseSplittingWord2Vec/src/Training/fixed_sense_embedding_training.jl:70
initialize_embedding(embed::WordEmbeddings.WordEmbedding, randomly::WordEmbeddings.RandomInited) at /mnt_volume/phd/prototypes/SenseSplittingWord2Vec/src/Training/word_embedding_training.jl:85
initialize_embedding(embed::WordEmbeddings.WordSenseEmbedding, ::WordEmbeddings.RandomInited) at /mnt_volume/phd/prototypes/SenseSplittingWord2Vec/src/Training/sense_embedding_training.jl:224
initialize_embedding(embed::WordEmbeddings.GenWordEmbedding) at /mnt_volume/phd/prototypes/SenseSplittingWord2Vec/src/Training/general_training.jl:66
```


In [54]:
function initialize_network!(embed::GenWordEmbedding, network_type::SemHuff)
    source_tree = network_type.source.classification_tree
    source_embeddings = network_type.source.embedding
    
    debug("Began SemHuff sorting")
    semtree = semhuff(source_tree, source_embeddings, 30);
    debug("Completed SemHuff sorting")
    debug("Began classification tree creation")
    embed.classification_tree = transform_tree(semtree, 
                            leaf_transform = word->word,
    internal_transform = dummy -> LinearClassifier(2,embed.dim))
    
    embed.codebook = Dict(leaves_of(classification_tree))
    debug("Completed SemHuff Bootstrapping")
    embed
end

initialize_network! (generic function with 1 method)

In [None]:
sem_ee = deepcopy(ee)
sem_ee.network_type = SemHuff(ee)
sem_ee.embedding = Dict(pstring(word)=>wv for (word, wv) in ee.embedding)
initialize_embedding(sem_ee,sem_ee.init_type)
initialize_network!(sem_ee,sem_ee.network_type)

2016-08-04T23:49:05.774 - info: semantically sorting level: 22
perfect matching with 6518 nodes and 195075 edges
    starting init...done [0.211 secs]. 10 trees
    .8.6.4.2.0.
done [0.235 secs]. 928 grows, 0 expands, 17 shrinks
    expands: [0.000 secs], shrinks: [0.000 secs], dual updates: [0.000 secs]
2016-08-04T23:49:19.086 - info: semantically sorting level: 21
