In [1]:
using WordNet
using WordEmbeddings, SoftmaxClassifier
using Utils
using Query
using Distances
using Iterators
using LightXML
using JLD

In [2]:
# Directory holding the WordNet 2.1 files; also passed to load_sense_key_map further down.
const WN_PATH="data/corpora/WordNet-2.1/"
# WordNet database handle built from that directory (the cell output reports type WordNet.DB).
db = DB(WN_PATH)

WordNet.DB

In [3]:
# Load the object stored under the name "ee" in the JLD file. It is later passed to
# functions that require a WordSenseEmbedding.
#ee = restore("models/ss/tokenised_lowercase_WestburyLab.wikicorp.201004_100_i1.jld");
ee = load("models/ss/tokenised_lowercase_WestburyLab.wikicorp.201004_100_i1.jld", "ee")

#ee = restore("models/ss/keep/tokenised_lowercase_WestburyLab.wikicorp.201004_50__m170000000.model");

In [4]:
"""
    poormans_tokenize(source)

Lowercase `source`, delete every punctuation character, and split what remains on
whitespace. Returns the tokens as a `Vector{SubString{String}}`.

Punctuation is removed rather than replaced by a space, so "don't" becomes "dont".
"""
function poormans_tokenize(source)
    cleaned = filter(c -> !ispunct(c), lowercase(source))
    # `split` already yields SubString tokens. Mapping `string` over them adds nothing and,
    # wherever `string` of a SubString returns a `String`, makes the return assertion throw.
    split(cleaned)::Vector{SubString{String}}
end

"""
    punctuation_space_tokenize(source)

Lowercase `source` and split it on whitespace, discarding punctuation that touches
whitespace or sits at the start or end of the text. Punctuation strictly inside a token
(such as the hyphen in "well-known") is kept.
"""
function punctuation_space_tokenize(source)
    text = lowercase(source)
    # Collapse each whitespace character, plus any punctuation hugging it, into one space.
    text = replace(text, r"[[:punct:]]*[[:space:]][[:punct:]]*", " ")
    # Strip whatever punctuation is left at the very start or very end.
    text = replace(text, r"[[:punct:]]*$|^[[:punct:]]*", "")
    split(text)
end



punctuation_space_tokenize (generic function with 1 method)

In [5]:
"""
    load_sense_key_map(wn_basedir)

Read `dict/index.sense` under `wn_basedir` and return a
`Dict{Tuple{Int64,String},String}` mapping `(synset offset, lemma name)` to the full
sense key.

Each line is split on whitespace into sense key, offset, sense number and tag count;
the lemma name is the part of the sense key before the first `%`. An assertion fails
if the same `(offset, lemma)` pair occurs twice.
"""
function load_sense_key_map(wn_basedir)
    index_file = joinpath(wn_basedir, "dict", "index.sense")
    mapping = Dict{Tuple{Int64,String},String}()
    for entry in eachline(index_file)
        key_text, offset_text, sense_num_text, tagcount_text = split(entry)
        word = first(split(key_text, '%'))
        lookup = (parse(Int64, offset_text), word)
        @assert(!haskey(mapping, lookup))
        mapping[lookup] = key_text
    end
    mapping
end


"""
    sense_key(lem::Lemma, ss::Synset, sense_key_map)

Return the entry of `sense_key_map` stored under `(ss.offset, lem.word)`.
The lookup is a plain index, so a missing pair raises whatever the map raises
(a `KeyError` for a `Dict`).
"""
function sense_key(lem::Lemma, ss::Synset, sense_key_map)
    lookup = (ss.offset, lem.word)
    sense_key_map[lookup]
end
"""
    sense_keys(lem::Lemma, sense_key_map)

Return one `sense_key_map` entry per offset in `lem.synset_offsets`, each looked up
under `(offset, lem.word)`, in the same order as the offsets.
"""
function sense_keys(lem::Lemma, sense_key_map)
    map(lem.synset_offsets) do offset
        sense_key_map[(offset, lem.word)]
    end
end

sense_keys (generic function with 1 method)

In [6]:
# Build the (synset offset, lemma name) -> sense key lookup from WN_PATH/dict/index.sense.
sense_key_map = load_sense_key_map(WN_PATH);

In [7]:
# Sanity check: sense keys for the lemma "six" looked up under part of speech 'a'.
sense_keys(db['a',"six"],sense_key_map)

1-element Array{String,1}:
 "six%5:00:00:cardinal:00"

In [8]:
# Sanity check: direct lookup by synset offset and lemma name.
sense_key_map[2134693,"discover"]

"discover%2:39:03::"

In [9]:
"""
One word-sense-disambiguation instance: a target word together with the tokens of the
sentence it came from.
"""
immutable WsdChallenge
    id::String                # instance identifier (the `id` attribute in the challenge XML)
    word::String              # lowercased text of the target element
    lemma::String             # lowercased `lemma` attribute of the target element
    pos::Char                 # first character of the `pos` attribute
    context::Vector{String}   # sentence tokens, excluding every token equal to `word`
end


In [10]:
"""
    load_challenges_semeval2007(xml_file=...)

Parse `xml_file` and return a `Task` that produces one `WsdChallenge` per annotated
element.

The document is walked three levels deep: children of the root, their children
(treated as sentences), and the children of each sentence (the instances to
disambiguate). For each instance the challenge carries its `id`, lowercased text,
lowercased `lemma` attribute, the first character of its `pos` attribute, and the
tokenised sentence with every token equal to the instance's word removed.

The parsed XML document is released once the task has run to completion (or failed).
"""
function load_challenges_semeval2007(xml_file="data/corpora/wsd/semeval2007_t7/test/eng-coarse-all-words.xml")
    xdoc = parse_file(xml_file)
    xroot = root(xdoc)
    Task() do
        try
            for text_node in child_elements(xroot)
                for sentence_node in child_elements(text_node)
                    sentence = punctuation_space_tokenize(content(sentence_node))

                    for lemma_node in child_elements(sentence_node)
                        word = content(lemma_node) |> lowercase
                        lemma = attribute(lemma_node,"lemma")|> lowercase
                        pos = first(attribute(lemma_node,"pos"))
                        id = attribute(lemma_node,"id")
                        context = filter(w->w!=word, sentence)
                        produce(WsdChallenge(id, word, lemma, pos, context))
                    end
                end
            end
        finally
            # LightXML documents are not garbage-collected; without this call the parsed
            # tree is leaked. Freeing here is safe because it runs only after the walk over
            # the nodes has finished.
            free(xdoc)
        end
    end
end
# Drain the producer task so that all challenges are held in an array.
challenges = load_challenges_semeval2007() |> collect;

In [11]:
"""
Answer-key entry for one WSD instance, as parsed from a line of the key file.
"""
immutable WsdSolution
    id          # instance identifier (second whitespace-separated field of the key line)
    lemma       # lemma text captured from the trailing `lemma=...#.` comment
    pos         # part-of-speech character that follows `#` in that comment
    solutions   # all remaining fields of the key line: the sense keys counted as correct
end
"""
    only_of_pos(data, pos)

Keep the elements of `data` whose `pos` field equals `pos`.
"""
function only_of_pos(data, pos)
    filter(data) do item
        item.pos == pos
    end
end

"""
    load_solutions(key_file=...)

Read the answer key, one `WsdSolution` per line.

Each line has the form `<doc> <instance id> <sense key>... !! lemma=<lemma>#<pos>`.
The part before `!!` is split on whitespace: the second field is the instance id and
everything from the third field on is the list of accepted sense keys. The lemma and
the single part-of-speech character are taken from the comment after `!!`.

Raises an error when the comment of a line carries no `lemma=<lemma>#<pos>` tag.
"""
function load_solutions(key_file="data/corpora/wsd/semeval2007_t7/key/dataset21.test.key")
    map(eachline(key_file)) do line
        line_data, comment = split(line,"!!")
        fields = split(line_data)
        instance_id = fields[2]      # fields[1] is the document id, which is not needed
        solutions = fields[3:end]

        tag = match(r"lemma=(.*)#(.)", comment)
        # Fail with the offending line rather than with a field access on `nothing`.
        tag === nothing && error("No lemma#pos tag in answer-key line: ", line)
        lemma, pos = tag.captures
        WsdSolution(instance_id, lemma, pos[1], solutions)
    end
end

# Parse the answer key from its default location.
solutions = load_solutions();

In [12]:
# Sanity check: show the second parsed answer-key entry.
solutions[2]

WsdSolution("d001.s001.t002","Ill",'a',SubString{String}["ill%3:00:01::"])

In [13]:
"""
    sense_frequency(ss::Synset, lem::Lemma)

Return `lem.tagsense_count` as the frequency score. `ss` is accepted but not consulted.
"""
function sense_frequency(ss::Synset, lem::Lemma)
    # The synset's other counts are deliberately left out: mixing them in, e.g.
    # lem_count + 0.1(sum(values(ss.word_counts)) - lem_count), worked out worse.
    lem.tagsense_count
end

sense_frequency (generic function with 1 method)

In [14]:
"""
    normal_probs(logprobs::Vector)

Turn a vector of log-probabilities into probabilities that sum to one (a softmax).
The input is not modified; a new vector is returned.

The maximum is subtracted before exponentiating so that very large or very small
log-probabilities do not overflow or underflow. Throws on an empty vector, because
the maximum of no elements is undefined.
"""
function normal_probs(logprobs::Vector)
    # Broadcasting allocates a fresh result of the right element type. This avoids the
    # in-place `map!(exp, ret)` form and also lets integer inputs produce floats.
    weights = exp.(logprobs .- maximum(logprobs))
    weights ./ sum(weights)
end
"""
    weighted_average(logprobs, embeddings)

Combine `embeddings` into a single vector, weighting each one by the probability that
`normal_probs` derives from the matching entry of `logprobs`.

`embeddings` must be non-empty: its first element fixes the shape and element type of
the accumulator.
"""
function weighted_average(logprobs, embeddings)
    # `zero` of an array yields a same-shaped array of zeros, exactly what `zeros(array)`
    # did, but this spelling is not deprecated in later Julia releases.
    ret = zero(first(embeddings))
    for (weight, embedding) in zip(normal_probs(logprobs), embeddings)
        ret .+= weight .* embedding
    end
    ret
end


"""
    synthesize_embedding(ee, context, phrase)

Build one embedding for `phrase` that is tailored to `context`.

`phrase` is split on single spaces. Every vector stored in `ee.embedding` for any of
the resulting words is gathered and scored against `context` with
`Query.logprob_of_context`; the vectors are then merged by `weighted_average` using
those scores. Throws a `KeyError` when none of the words has an entry in
`ee.embedding`.
"""
function synthesize_embedding(ee, context, phrase)
    words = split(phrase, " ")
    # Words without an entry contribute an empty list, so they simply drop out.
    candidates = vcat((get(ee.embedding, w, Vector{Float32}[]) for w in words)...)
    isempty(candidates) && throw(KeyError("None of $words have embeddings"))

    scores = [Query.logprob_of_context(ee, context, wv; skip_oov=true, normalise_over_length=true)
              for wv in candidates]
    weighted_average(scores, candidates)
end


"""
    all_identical(col)

Return `true` when `col` has exactly one element, or when every element after the
first compares equal to the first.
"""
function all_identical(col)
    if length(col) == 1
        return true
    end
    all(x -> x == col[1], col[2:end])
end

"""
    lexically_informed_embeddings(db, ee, lemma_word, pos)
    lexically_informed_embeddings(db, ee, word, lemma_word, pos)

Produce one embedding per WordNet synset of `lemma_word` (looked up under `pos`) and
return `(embeddings, target_synsets, lemma)`.

For each synset, the gloss is tokenised, tokens equal to `word` are dropped, and the
remaining tokens serve as the context for `synthesize_embedding(ee, context, word)`.
If all resulting embeddings are identical they cannot separate the senses; in that
case only the synset with the highest `sense_frequency` is returned, paired with a
single embedding.

The four-argument form uses `lemma_word` as the word as well.
"""
function lexically_informed_embeddings(db, ee, lemma_word, pos)
    lexically_informed_embeddings(db, ee, lemma_word, lemma_word, pos)
end

function lexically_informed_embeddings(db, ee, word, lemma_word, pos)
    lemma = db[pos, lemma_word]
    target_synsets::Vector{Synset} = synsets(db, lemma)

    # Context for a synset: its gloss tokens, minus the word being disambiguated.
    gloss_context(synset) = convert(Vector{SubString{String}},
        collect(filter!(w -> w != word, punctuation_space_tokenize(synset.gloss))))

    embeddings = Vector{Float32}[synthesize_embedding(ee, gloss_context(synset), word)
                                 for synset in target_synsets]

    if all_identical(embeddings)
        # The glosses gave no discrimination: fall back to the most frequent sense.
        best = findmax(sense_frequency(ss, lemma) for ss in target_synsets)[2]
        target_synsets = [target_synsets[best]]   # discard the other synsets
        embeddings = [embeddings[1]]
    end

    embeddings, target_synsets, lemma
end

lexically_informed_embeddings (generic function with 2 methods)

In [15]:
"""
    supervised_wsd(challenge::WsdChallenge, ee::WordSenseEmbedding, wn::DB, sense_key_map)

Pick a sense for `challenge` using gloss-informed embeddings.

One embedding per candidate synset is obtained from `lexically_informed_embeddings`;
each is scored by `logprob_of_context` against the challenge's context, and the sense
key of the best-scoring synset is returned wrapped in a `Nullable`. If any step raises
a `KeyError`, an empty `Nullable{String}` is returned instead; every other exception
is rethrown.
"""
function supervised_wsd(challenge::WsdChallenge, ee::WordSenseEmbedding, wn::DB, sense_key_map)
    try
        sense_vectors, candidates, lemma = lexically_informed_embeddings(
            wn, ee, challenge.word, challenge.lemma, challenge.pos)

        context_scores = [logprob_of_context(ee, challenge.context, wv;
                                             skip_oov=true, normalise_over_length=true)
                          for wv in sense_vectors]

        chosen = candidates[indmax(context_scores)]
        Nullable(sense_key(lemma, chosen, sense_key_map))
    catch ex
        if !isa(ex, KeyError)
            rethrow(ex)
        end
        Nullable{String}()
    end
end

"""
    most_frequent_sense(challenge, wn::DB, sense_key_map)

Baseline disambiguator: return, wrapped in a `Nullable`, the sense key of the synset
of `challenge.lemma` (under `challenge.pos`) that has the highest `sense_frequency`.
The context is ignored. If a lookup raises a `KeyError`, an empty `Nullable{String}`
is returned; every other exception is rethrown.
"""
function most_frequent_sense(challenge, wn::DB, sense_key_map)
    try
        lemma = wn[challenge.pos, challenge.lemma]
        # Query the database that was passed in. The global `db` used here before only
        # worked because every caller happened to pass that same object as `wn`.
        target_synsets::Vector{Synset} = synsets(wn, lemma)

        sense_freqs = Float32[sense_frequency(ss,lemma) for ss in target_synsets]
        sense_index = indmax(sense_freqs)
        synset = target_synsets[sense_index]

        sk = sense_key(lemma, synset, sense_key_map)
        Nullable(sk)
    catch ex
        isa(ex, KeyError) || rethrow(ex)
        Nullable{String}()
    end
end



most_frequent_sense (generic function with 1 method)

In [16]:
"""
    evalute_on_wsd_challenge(challenges, solutions, method)

Score `method` on paired `challenges` and `solutions` and return
`(precision, recall, f1)`.

`method(challenge)` must return a `Nullable`: null means the instance was not
attempted, otherwise the wrapped value counts as correct when it is one of
`ground_solution.solutions`. The two sequences are walked in lockstep and the ids of
each pair are asserted to match.

Precision is correct / attempted and recall is correct / all instances. F1 is their
harmonic mean, reported as `0.0` when precision and recall are both zero. With no
attempted instances precision (and therefore F1) is `NaN`.
"""
function evalute_on_wsd_challenge(challenges, solutions, method)
    correct = 0
    incorrect = 0
    notattempted = 0
    for (challenge, ground_solution) in zip(challenges, solutions)
        # Macro form, as used elsewhere in this file; guards against misaligned inputs.
        @assert(challenge.id == ground_solution.id)
        output_sense = method(challenge)
        if isnull(output_sense)
            notattempted += 1
        elseif get(output_sense) ∈ ground_solution.solutions
            correct += 1
        else
            incorrect += 1
        end
    end
    attempted = correct + incorrect
    precision = correct / attempted
    recall = correct / (attempted + notattempted)
    # Attempts were made but none was right: the harmonic mean would be 0/0, so define
    # it as zero instead of letting NaN leak into the report.
    f1 = precision + recall == 0 ? 0.0 : (2 * precision * recall) / (precision + recall)
    return precision, recall, f1
end
    
    
"""
    evalute_on_wsd_challenge(challenges, solutions, ee::WordSenseEmbedding, wn::DB, sense_key_map)

Evaluate `supervised_wsd` with the given embedding model, WordNet database and sense
key map, returning the `(precision, recall, f1)` of the three-argument method.
"""
function evalute_on_wsd_challenge(challenges, solutions, ee::WordSenseEmbedding, wn::DB, sense_key_map)
    disambiguate(challenge) = supervised_wsd(challenge, ee, wn, sense_key_map)
    evalute_on_wsd_challenge(challenges, solutions, disambiguate)
end


"""
    evalute_on_wsd_challenge_MFS(challenges, solutions, wn::DB, sense_key_map)

Evaluate the `most_frequent_sense` baseline, returning the `(precision, recall, f1)`
computed by `evalute_on_wsd_challenge`.
"""
function evalute_on_wsd_challenge_MFS(challenges, solutions, wn::DB, sense_key_map)
    baseline(challenge) = most_frequent_sense(challenge, wn, sense_key_map)
    evalute_on_wsd_challenge(challenges, solutions, baseline)
end

evalute_on_wsd_challenge_MFS (generic function with 1 method)

In [17]:
# Score the embedding-based disambiguator on the full test set, then per part of speech.
println("overall: \t\t", evalute_on_wsd_challenge(challenges, solutions, ee, db, sense_key_map))

for (label, pos) in (("only nouns  : \t\t", 'n'), ("only verbs  : \t\t", 'v'),
                     ("only adjecti: \t\t", 'a'), ("only adverbs: \t\t", 'r'))
    subset_challenges = only_of_pos(challenges, pos)
    subset_solutions = only_of_pos(solutions, pos)
    println(label, evalute_on_wsd_challenge(subset_challenges, subset_solutions, ee, db, sense_key_map))
end


overall: 		(0.6486486486486487,0.6346408109299251,0.6415682780129205)
only nouns  : 		(0.6593001841620626,0.6462093862815884,0.6526891522333637)
only verbs  : 		(0.5785837651122625,0.5668358714043993,0.5726495726495726)
only adjecti: 		(0.6762177650429799,0.6519337016574586,0.6638537271448665)
only adverbs: 		(0.7427184466019418,0.7355769230769231,0.7391304347826086)


In [18]:
# Score the most-frequent-sense baseline on the full test set, then per part of speech.
println("overall: \t\t", evalute_on_wsd_challenge_MFS(challenges, solutions, db, sense_key_map))

for (label, pos) in (("only nouns  : \t\t", 'n'), ("only verbs  : \t\t", 'v'),
                     ("only adjecti: \t\t", 'a'), ("only adverbs: \t\t", 'r'))
    subset_challenges = only_of_pos(challenges, pos)
    subset_solutions = only_of_pos(solutions, pos)
    println(label, evalute_on_wsd_challenge_MFS(subset_challenges, subset_solutions, db, sense_key_map))
end


overall: 		(0.7783164389598942,0.7783164389598942,0.7783164389598942)
only nouns  : 		(0.7653429602888087,0.7653429602888087,0.7653429602888087)
only verbs  : 		(0.7529610829103215,0.7529610829103215,0.7529610829103215)
only adjecti: 		(0.8038674033149171,0.8038674033149171,0.8038674033149171)
only adverbs: 		(0.875,0.875,0.875)


In [None]:
#Zero shot WSI for 1 shot WSD
using Training

In [None]:
# Create a second embedding model with the same dimension as `ee` and `initial_nsenses=1`.
# NOTE(review): `random_inited` and `huffman_tree` are not defined in the visible part of
# this notebook; presumably they come from one of the imported packages. Confirm.
rr = FixedWordSenseEmbedding(ee.dimension, random_inited, huffman_tree; initial_nsenses=1)
# Share the distribution, codebook and classification tree of `ee` with the new model,
# then let Training initialise its embedding.
rr.distribution = ee.distribution
rr.codebook = ee.codebook
rr.classification_tree = ee.classification_tree
Training.initialize_embedding(rr)

In [None]:
# Raise `initial_nsenses` to 100 and initialise the embedding of `rr` again.
rr.initial_nsenses=100
Training.initialize_embedding(rr)

In [None]:
# Score the disambiguator with the `rr` model on the full test set, then per part of speech.
println("overall: \t\t", evalute_on_wsd_challenge(challenges, solutions, rr, db, sense_key_map))

for (label, pos) in (("only nouns  : \t\t", 'n'), ("only verbs  : \t\t", 'v'),
                     ("only adjecti: \t\t", 'a'), ("only adverbs: \t\t", 'r'))
    subset_challenges = only_of_pos(challenges, pos)
    subset_solutions = only_of_pos(solutions, pos)
    println(label, evalute_on_wsd_challenge(subset_challenges, subset_solutions, rr, db, sense_key_map))
end


In [None]:
# Combine the two models: start from a deep copy of `rr`, then for every word keyed in
# `rr.embedding` append that word's entries from `ee.embedding`.
# Indexing `ee.embedding[word]` assumes each such word also exists in `ee` — TODO confirm.
hh = deepcopy(rr)
for word in keys(rr.embedding)
    append!(hh.embedding[word], ee.embedding[word])
end

In [None]:
# Score the disambiguator with the combined `hh` model on the full test set, then per part of speech.
println("overall: \t\t", evalute_on_wsd_challenge(challenges, solutions, hh, db, sense_key_map))

for (label, pos) in (("only nouns  : \t\t", 'n'), ("only verbs  : \t\t", 'v'),
                     ("only adjecti: \t\t", 'a'), ("only adverbs: \t\t", 'r'))
    subset_challenges = only_of_pos(challenges, pos)
    subset_solutions = only_of_pos(solutions, pos)
    println(label, evalute_on_wsd_challenge(subset_challenges, subset_solutions, hh, db, sense_key_map))
end

In [None]:
using JLD