Tests for indexing and other refinements
zgornel committed Sep 20, 2018
1 parent 78cbae6 commit 1f86144
Showing 3 changed files with 46 additions and 31 deletions.
src/files.jl (20 changes: 10 additions & 10 deletions)

@@ -20,22 +20,22 @@ end


 # Function that loads the embeddings given a valid ConceptNetNumberbatch file
-function load_embeddings(file::AbstractString;
+function load_embeddings(filepath::AbstractString;
                          max_vocab_size::Union{Nothing,Int}=nothing,
                          keep_words=String[],
                          language=:unknown)
-    if any(endswith.(file, [".gz", ".gzip"]))
-        conceptnet = _load_gz_embeddings(file,
+    if any(endswith.(filepath, [".gz", ".gzip"]))
+        conceptnet = _load_gz_embeddings(filepath,
                                          GzipDecompressor(),
                                          max_vocab_size,
                                          keep_words,
                                          language=language)
-    elseif any(endswith.(file, [".h5", ".hdf5"]))
-        conceptnet = _load_hdf5_embeddings(file,
+    elseif any(endswith.(filepath, [".h5", ".hdf5"]))
+        conceptnet = _load_hdf5_embeddings(filepath,
                                            max_vocab_size,
                                            keep_words)
     else
-        conceptnet = _load_gz_embeddings(file,
+        conceptnet = _load_gz_embeddings(filepath,
                                          Noop(),
                                          max_vocab_size,
                                          keep_words,

@@ -46,7 +46,7 @@ end


 # Loads the ConceptNetNumberbatch from a .gz or uncompressed file
-function _load_gz_embeddings(file::S1,
+function _load_gz_embeddings(filepath::S1,
                              decompressor::TranscodingStreams.Codec,
                              max_vocab_size::Union{Nothing,Int},
                              keep_words::Vector{S2};

@@ -56,7 +56,7 @@ function _load_gz_embeddings(file::S1,
     _length::Int, _width::Int
     type_words = Vector{String}
     type_matrix = Matrix{Float64}
-    open(file, "r") do fid
+    open(filepath, "r") do fid
        cfid = TranscodingStream(decompressor, fid)
        _length, _width = parse.(Int64, split(readline(cfid)))
        embeddings_words = type_words(undef, _length)

@@ -96,13 +96,13 @@ end


 # Loads the ConceptNetNumberbatch from a HDF5 file
-function _load_hdf5_embeddings(file::S1,
+function _load_hdf5_embeddings(filepath::S1,
                                max_vocab_size::Union{Nothing,Int},
                                keep_words::Vector{S2}) where
        {S1<:AbstractString, S2<:AbstractString}
     type_words = Vector{String}
     type_matrix = Matrix{Int8}
-    payload = h5open(read, file)["mat"]
+    payload = h5open(read, filepath)["mat"]
     words = payload["axis1"]
     embeddings = payload["block0_values"]
     vocab_size = _get_vocab_size(length(words),
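Aside from the rename of `file` to `filepath`, the loader's dispatch is unchanged: `.gz`/`.gzip` files go through `GzipDecompressor()`, `.h5`/`.hdf5` files through the HDF5 reader, and anything else is read as an uncompressed stream via `Noop()`. A minimal usage sketch; the path is hypothetical and the three-value destructuring simply mirrors the test file below, not a documented API contract:

    using ConceptNetNumberbatch  # package name assumed from the repository
    # Load a compressed English ConceptNetNumberbatch file (hypothetical path),
    # capping the vocabulary and tagging the language for dispatch.
    conceptnet, _, _ = load_embeddings("data/numberbatch-en.txt.gz";
                                       max_vocab_size=10_000,
                                       keep_words=String[],
                                       language=:en)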
src/interface.jl (33 changes: 12 additions & 21 deletions)

@@ -12,7 +12,7 @@ ConceptNet(words::W, embeddings::E) where
 const ConceptNetMulti = ConceptNet{:multi, Vector{String}, Matrix{Float64}}
 const ConceptNetMultiCompressed = ConceptNet{:multi_c, Vector{String}, Matrix{Int8}}
 const ConceptNetEnglish = ConceptNet{:en, Vector{String}, Matrix{Float64}}
-const ConceptNetUnknown = ConceptNet{:ubknown, Vector{String}, Matrix{Float64}}
+const ConceptNetUnknown = ConceptNet{:unknown, Vector{String}, Matrix{Float64}}


 # Show methods

@@ -32,35 +32,26 @@ show(io::IO, conceptnet::ConceptNetUnknown) = begin
     print(io, "ConceptNet (Unknown language) with $(length(conceptnet.words)) embeddings")
 end


 # getindex methods
-function getindex(conceptnet::ConceptNetMultiCompressed, words::S) where
-    {S<:AbstractVector{<:AbstractString}}
-    @warn "Results may be wrong!"
-    return conceptnet.embeddings[:, findall((in)(words), conceptnet.words)]
-end
-
-function getindex(conceptnet::ConceptNetMulti, words::S) where
-    {S<:AbstractVector{<:AbstractString}}
-    @warn "Results may be wrong!"
-    return conceptnet.embeddings[:, findall((in)(words), conceptnet.words)]
-end
-
-function getindex(conceptnet::ConceptNetEnglish, words::S) where
+# TODO(Corneliu):
+# - specific implementation for multilanguage files (w. language detection)
+# - add OOV - pre-processing functions
+function getindex(conceptnet::ConceptNet, words::S) where
     {S<:AbstractVector{<:AbstractString}}
     lenemb = size(conceptnet.embeddings, 1)
     embeddings = zeros(eltype(conceptnet.embeddings), lenemb, length(words))
-    indices = indexin(conceptnet.words, words)
-    for idx in indices
+    indices = indexin(words, conceptnet.words)
+    for (i, idx) in enumerate(indices)
         if idx != nothing
-            embeddings[:,idx] = conceptnet.embeddings[:, idx]
+            embeddings[:,i] = conceptnet.embeddings[:, idx]
         end
     end
     return embeddings
 end

-getindex(::ConceptNetUnknown, words::S) where {S<:Vector{<:AbstractString}} =
-    @error "Indexing not supported for an :unknown language ConceptNet"
+getindex(::ConceptNetUnknown, words::S) where {S<:Vector{<:AbstractString}} = begin
+    throw(ArgumentError("Indexing not supported for an :unknown language ConceptNet"))
+end


 getindex(conceptnet::ConceptNet, word::S where S<:AbstractString)= conceptnet[[word]]
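The substantive fix in the consolidated `getindex` is the argument order of `indexin`: in Julia, `indexin(a, b)` returns, for each element of `a`, the index of its first occurrence in `b` (or `nothing` if absent). The result therefore has to be aligned with the queried `words`, not with the vocabulary, so that column `i` of the output can be copied from vocabulary column `idx` and words that are not found keep their zero column. A toy illustration with made-up data, not from the package:

    vocab = ["apple", "banana", "carrot"]
    query = ["banana", "missing", "apple"]
    indexin(query, vocab)  # [2, nothing, 1]: one entry per queried word (new order)
    indexin(vocab, query)  # [3, 1, nothing]: one entry per vocabulary word (old, buggy order)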
test/runtests.jl (24 changes: 24 additions & 0 deletions)

@@ -57,6 +57,30 @@ max_vocab_size=5
     end
 end

+@testset "getindex, size, length" begin
+    filepath = joinpath(string(@__DIR__), "data", "_test_file_en.txt.gz")
+    # Known language
+    conceptnet, _, _ = load_embeddings(filepath, language=:en)
+    words = ["####_ish", "####_form", "####_metres", "not_found", "not_found2"]
+    embeddings = conceptnet[words]
+    for (idx, word) in enumerate(words)
+        if word in conceptnet.words
+            @test embeddings[:,[idx]] ==
+                  conceptnet.embeddings[:, indexin([word], conceptnet.words)]
+        else
+            @test iszero(embeddings[:,idx])
+        end
+    end
+    # Unknown language
+    conceptnet, _, _ = load_embeddings(filepath)  # unknown language
+    @test_throws ArgumentError conceptnet[words]
+    # length
+    @test length(conceptnet) == length(conceptnet.words)
+    # size
+    @test size(conceptnet) == size(conceptnet.embeddings)
+end


 # show methods
 @testset "Show methods" begin
     buf = IOBuffer()
