Skip to content

Commit

Permalink
Added pca_reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
zgornel committed Jun 4, 2019
2 parents c9c9be0 + 58b2400 commit 3efa5ec
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 21 deletions.
71 changes: 52 additions & 19 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This file is machine-generated - editing it directly is not advised

[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.1"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand All @@ -10,10 +16,10 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
version = "0.8.10"

[[BinaryProvider]]
deps = ["Libdl", "Pkg", "SHA", "Test"]
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
deps = ["Libdl", "SHA"]
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.3"
version = "0.5.4"

[[Blosc]]
deps = ["BinaryProvider", "CMakeWrapper", "Compat", "Libdl"]
Expand All @@ -35,12 +41,9 @@ version = "1.1.1"

[[CMakeWrapper]]
deps = ["BinDeps", "CMake", "Libdl", "Parameters", "Test"]
git-tree-sha1 = "2b43d451639984e3571951cc687b8509b0a86c6d"
git-tree-sha1 = "16d4acb3d37dc05b714977ffefa8890843dc8985"
uuid = "d5fb7624-851a-54ee-a528-d3f3bac0b4a0"
version = "0.2.2"

[[CRC32c]]
uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
version = "0.2.3"

[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
Expand All @@ -60,6 +63,12 @@ git-tree-sha1 = "4c116b1b8bb103056face999b60b4547339a0c01"
uuid = "6bdbf80b-0969-53f9-8443-f41591bd656e"
version = "0.1.5"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand All @@ -73,10 +82,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[HDF5]]
deps = ["BinDeps", "Blosc", "CRC32c", "Distributed", "Homebrew", "Libdl", "LinearAlgebra", "Mmap", "Pkg", "Test", "WinRPM"]
git-tree-sha1 = "dd83e1e9c72e44e3a156438b552cf75dbdda722f"
deps = ["BinDeps", "Blosc", "Homebrew", "Libdl", "Mmap", "WinRPM"]
git-tree-sha1 = "e6f0c154d01faef0d0831d075aa8f279f95946da"
uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
version = "0.11.0"
version = "0.11.1"

[[HTTPClient]]
deps = ["Compat", "LibCURL"]
Expand Down Expand Up @@ -107,10 +116,10 @@ uuid = "8ef0a80b-9436-5d2c-a485-80b904378c43"
version = "0.4.2"

[[LibCURL]]
deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "d051c8057512ca38a273aaa514145a0b25f24d46"
deps = ["BinaryProvider", "Libdl"]
git-tree-sha1 = "5ee138c679fa202ebe211b2683d1eee2a87b3dbe"
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.5.0"
version = "0.5.1"

[[LibExpat]]
deps = ["Compat"]
Expand Down Expand Up @@ -141,14 +150,26 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

[[Missings]]
deps = ["SparseArrays", "Test"]
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.1"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[MultivariateStats]]
deps = ["Arpack", "LinearAlgebra", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase", "Test"]
git-tree-sha1 = "cf1c990020bc4a52ff34ba2ee058b7cb677141f2"
uuid = "6f286f6a-111f-5878-ab1e-185364afe411"
version = "0.6.0"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
version = "1.1.0"

[[Parameters]]
deps = ["Markdown", "OrderedCollections", "REPL", "Test"]
Expand Down Expand Up @@ -185,6 +206,12 @@ uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures", "Random", "Test"]
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "0.3.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -193,15 +220,21 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.30.0"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TranscodingStreams]]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "f42956022d8084539f1d7219f632542b0ea686ce"
deps = ["Random", "Test"]
git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.3"
version = "0.9.4"

[[URIParser]]
deps = ["Test", "Unicode"]
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ version = "0.1.0"
ConceptnetNumberbatch = "6bdbf80b-0969-53f9-8443-f41591bd656e"
Languages = "8ef0a80b-9436-5d2c-a485-80b904378c43"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Word2Vec = "c64b6f0f-98cd-51d1-af78-58ae84944834"
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ julia 1.0
Languages
Word2Vec
ConceptnetNumberbatch
MultivariateStats
6 changes: 5 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ EmbeddingsAnalysis is a package for processing embeddings. At this point, only w

## Processing methods
The package implements the following embeddings processing algorithms:
- [Artetxe et al. "Uncovering divergent linguistic information in word embeddings with lessons for intrinsic and extrinsic evaluation", 2018](https://arxiv.org/pdf/1809.02094.pdf)
- [Artetxe et al. "Uncovering divergent linguistic information in word embeddings with lessons for intrinsic and extrinsic evaluation", 2018](https://arxiv.org/pdf/1809.02094.pdf)
- [Vikas Raunak "Simple and effective dimensionality reduction for word embeddings", NIPS 2017 Workshop](https://arxiv.org/abs/1708.03629)
and utilities:
- saving `WordVectors` objects to disk in either binary or text format
- converting `ConceptNet` objects to `WordVectors` objects

## Installation

Expand Down
7 changes: 6 additions & 1 deletion src/EmbeddingsAnalysis.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# Top-level module of the package: brings the dependencies into scope,
# declares the exported names and includes the implementation files.
module EmbeddingsAnalysis

using LinearAlgebra
using Statistics
using Languages
using Word2Vec
using ConceptnetNumberbatch
using MultivariateStats

# Explicit import, which allows new methods to be added to `Base.dump`
# (presumably done in dump.jl; confirm against that file)
import Base: dump

# Exported names of the package.
# NOTE(review): the two consecutive `similarity_order` lines below look like
# the removed/added pair of a diff rendering; confirm against the actual file
# that only the comma-terminated line is present.
export conceptnet2wv,
similarity_order
similarity_order,
pca_reduction

# Implementation files, included into this module's namespace
include("dump.jl")
include("conceptnet2wv.jl")
include("similarity_order.jl")
include("pca_reduction.jl")

end # module
84 changes: 84 additions & 0 deletions src/pca_reduction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
pca_reduction(wv::WordVectors, rdim=7, outdim=size(wv.vectors,1); [do_pca=true])
Post-processes word embeddings `wv` by removing the first `rdim` PCA components
from the word vectors and also reduces the dimensionality to `outdim` through
a subsequent PCA transform, if `do_pca=true`.
# Arguments
* `wv::WordVectors` the word embeddings
* `rdim::Int` the number of PCA components to remove from the data
(default 7)
* `outdim::Int` the output dimensionality of the data after the PCA
dimensionality reduction; it is performed only if `do_pca=true`
and the default value is the same as that of the input embeddings
i.e. no reduction
# Keyword arguments
* `do_pca::Bool` whether to perform a PCA transform of the
post-processed data (default `true`)
# References:
* [Vikas Raunak "Simple and effective dimensionality reduction for
word embeddings", NIPS 2017 Workshop](https://arxiv.org/abs/1708.03629)
"""
function pca_reduction(wv::WordVectors{S,T,H},
                       rdim::Int=7,
                       outdim::Int=size(wv.vectors,1);
                       do_pca::Bool=true
                      ) where {S<:AbstractString, T<:Real, H<:Integer}
    # Dimensionality of the input embeddings (rows of `wv.vectors`)
    indim = size(wv.vectors, 1)

    # First post-processing pass over the raw embeddings
    processed = _pca_postprocessing(wv.vectors, rdim)

    # Without the PCA step, the post-processed vectors are the final result
    if !do_pca
        return WordVectors{S,T,H}(wv.vocab, processed, wv.vocab_hash)
    end

    # The requested output dimensionality is limited to 1..(number of rows)
    target = clamp(outdim, 1, size(processed, 1))

    # `pratio` is 1.0 only when no reduction was requested, i.e. the target
    # equals the input dimensionality; otherwise the model is fitted with 0.99
    ratio = indim == target ? 1.0 : 0.99
    model = fit(PCA, processed, maxoutdim=target, pratio=ratio)

    # Project through the fitted model, then post-process a second time
    reduced = _pca_postprocessing(transform(model, processed), rdim)

    # The vocabulary and its hash are reused unchanged from the input
    return WordVectors{S,T,H}(wv.vocab, reduced, wv.vocab_hash)
end


# Post-processes the data matrix `X` (one sample per column): the per-row mean
# is subtracted and the projections of every sample onto the leading `rdim`
# principal directions of the centered data are removed.
#
# Arguments:
#   * `X`    the data matrix; it is NOT modified, a new matrix is returned
#   * `rdim` the number of leading principal directions to eliminate; values
#            below 1 are raised to 1 and values larger than the number of
#            principal directions found by the PCA are lowered to that number
#
# Returns a new matrix with the same size as `X`.
function _pca_postprocessing(X::AbstractMatrix{T}, rdim::Int=7) where {T<:AbstractFloat}
    # Subtract the mean; a new matrix is allocated on purpose so that the
    # caller's data (e.g. the `vectors` field of the input embeddings) is
    # never overwritten
    @debug "Subtracting the mean..."
    Xc = X .- mean(X, dims=2)

    # Compute the principal directions; the data is already centered,
    # hence `mean=0`
    @debug "Computing PCA components..."
    M = fit(PCA, Xc, pratio=1.0, mean=0)
    U = projection(M)  # one principal direction per column

    # The PCA may return fewer directions than rows (e.g. rank-deficient
    # data), so the number of removed directions is limited accordingly
    rdim = min(max(rdim, 1), size(U, 2))

    # Eliminate top d components: x <- x - sum_i u_i * (u_i' * x), computed
    # for all columns at once
    @debug "Eliminating the top $rdim components..."
    Ud = U[:, 1:rdim]
    return Xc .- Ud * (Ud' * Xc)
end


# Pads a PCA model with zero-valued components so that its number of
# components equals `m`, the explicitly requested number of dimensions.
# A model that already has at least `m` components is returned unchanged.
# Assumes the projection matrix of `M` has `m` rows.
function __handle_pca_dimensions(M::PCA{T}, m) where {T}
    ncomp = length(M.prinvars)

    # Nothing to pad
    ncomp < m || return M

    # Square projection matrix: the existing components come first,
    # the remaining columns stay zero
    padded_proj = zeros(T, m, m)
    padded_proj[:, 1:ncomp] .+= M.proj

    # Principal variances, padded with zeros in the same way
    padded_vars = zeros(T, m)
    padded_vars[1:ncomp] .+= M.prinvars

    # Mean and variance totals are carried over from the original model
    return PCA(M.mean, padded_proj, padded_vars, M.tprinvar, M.tvar)
end
11 changes: 11 additions & 0 deletions test/pca_reduction.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Runs `pca_reduction` over a small grid of parameters and checks that the
# result keeps the type, the vocabulary and the number of words of the input.
@testset "PCA reduction" begin
    wv = fake_wordvectors()
    for d in [1, 10]
        for outdim in [1, 10]
            for do_pca in [false, true]
                wv_pca = pca_reduction(wv, d, outdim, do_pca=do_pca)
                # Assert on the returned object; testing `wv` against its
                # own type is always true and verifies nothing
                @test wv_pca isa typeof(wv)
                @test wv_pca.vocab == wv.vocab
                @test size(wv_pca.vectors, 2) == size(wv.vectors, 2)
            end
        end
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ using ConceptnetNumberbatch
include("dump.jl")
include("conceptnet2wv.jl")
include("similarity_order.jl")
include("pca_reduction.jl")

end

0 comments on commit 3efa5ec

Please sign in to comment.